diff --git a/analysis_configurations/template_analysis/generate.py b/analysis_configurations/template_analysis/generate.py index ba1cda60..62a72389 100644 --- a/analysis_configurations/template_analysis/generate.py +++ b/analysis_configurations/template_analysis/generate.py @@ -42,6 +42,7 @@ def run(args): available_samples, available_eras, available_scopes, + DAG_dir=f"{args.output}/visualization", ) # create a CodeGenerator object generator = CodeGenerator( diff --git a/analysis_configurations/template_analysis/generate_friends.py b/analysis_configurations/template_analysis/generate_friends.py index 805e60a5..b5ec0200 100644 --- a/analysis_configurations/template_analysis/generate_friends.py +++ b/analysis_configurations/template_analysis/generate_friends.py @@ -58,6 +58,7 @@ def run(args): available_eras, available_scopes, args.quantities_map, + DAG_dir=f"{args.output}/visualization", ) # check if the config is of type FriendTreeConfiguration if not isinstance(code_generation_config, FriendTreeConfiguration): diff --git a/analysis_configurations/template_analysis/producers/event.py b/analysis_configurations/template_analysis/producers/event.py index 770f5b8b..6566fc79 100644 --- a/analysis_configurations/template_analysis/producers/event.py +++ b/analysis_configurations/template_analysis/producers/event.py @@ -79,6 +79,7 @@ output=None, scopes=["global"], vec_configs=["met_filters"], + is_filter=True, ) Lumi = Producer( diff --git a/analysis_configurations/template_analysis/producers/scalefactors.py b/analysis_configurations/template_analysis/producers/scalefactors.py index e6c55de0..8aa0a703 100644 --- a/analysis_configurations/template_analysis/producers/scalefactors.py +++ b/analysis_configurations/template_analysis/producers/scalefactors.py @@ -9,8 +9,8 @@ Muon_1_Reco_SF = Producer( name="Muon_1_Reco_SF", call="""physicsobject::muon::scalefactor::Reco( - {df}, - correctionManager, + {df}, + correctionManager, {output}, {input}, "{muon_sf_file}", @@ -24,8 +24,8 @@ Muon_1_ID_SF = Producer( name="Muon_1_ID_SF", call="""physicsobject::muon::scalefactor::Id( - {df}, - correctionManager, + {df}, + correctionManager, {output}, {input}, "{muon_sf_file}", @@ -39,8 +39,8 @@ Muon_1_Iso_SF = Producer( name="Muon_1_Iso_SF", call="""physicsobject::muon::scalefactor::Iso( - {df}, - correctionManager, + {df}, + correctionManager, {output}, {input}, "{muon_sf_file}", @@ -55,8 +55,8 @@ Muon_2_Reco_SF = Producer( name="Muon_2_Reco_SF", call="""physicsobject::muon::scalefactor::Reco( - {df}, - correctionManager, + {df}, + correctionManager, {output}, {input}, "{muon_sf_file}", @@ -70,8 +70,8 @@ Muon_2_ID_SF = Producer( name="Muon_2_ID_SF", call="""physicsobject::muon::scalefactor::Id( - {df}, - correctionManager, + {df}, + correctionManager, {output}, {input}, "{muon_sf_file}", @@ -85,8 +85,8 @@ Muon_2_Iso_SF = Producer( name="Muon_2_Iso_SF", call="""physicsobject::muon::scalefactor::Iso( - {df}, - correctionManager, + {df}, + correctionManager, {output}, {input}, "{muon_sf_file}", diff --git a/analysis_configurations/template_analysis/template_config.py b/analysis_configurations/template_analysis/template_config.py index 21495a68..733dcfe9 100644 --- a/analysis_configurations/template_analysis/template_config.py +++ b/analysis_configurations/template_analysis/template_config.py @@ -14,6 +14,7 @@ from code_generation.modifiers import EraModifier from code_generation.rules import RemoveProducer, AppendProducer from code_generation.systematics import SystematicShift +from code_generation.utility.generate_DAG import create_graph def build_config( @@ -24,6 +25,7 @@ def build_config( available_sample_types: List[str], available_eras: List[str], available_scopes: List[str], + DAG_dir: str = "", ): configuration = Configuration( era, @@ -260,4 +262,8 @@ def build_config( configuration.optimize() configuration.validate() configuration.report() - return configuration.expanded_configuration() + if DAG_dir: + nanoAOD_inputs = [n for n in dir(nanoAOD) if not n.startswith("__")] + create_graph(configuration, nanoAOD_inputs, DAG_dir, "CROWNelements") + configuration = configuration.expanded_configuration() + return configuration diff --git a/analysis_configurations/template_analysis/template_friend_config.py b/analysis_configurations/template_analysis/template_friend_config.py index 6de2c931..22fd732d 100644 --- a/analysis_configurations/template_analysis/template_friend_config.py +++ b/analysis_configurations/template_analysis/template_friend_config.py @@ -4,9 +4,11 @@ from .producers import pairselection as pairselection from .producers import genparticles as genparticles +from .quantities import nanoAOD as nanoAOD from .quantities import output as q from code_generation.friend_trees import FriendTreeConfiguration from code_generation.rules import RemoveProducer +from code_generation.utility.generate_DAG import create_graph def build_config( @@ -18,6 +20,7 @@ def build_config( available_eras: List[str], available_scopes: List[str], quantities_map: Union[str, None] = None, + DAG_dir: str = "", ): configuration = FriendTreeConfiguration( era, @@ -67,4 +70,8 @@ def build_config( configuration.optimize() configuration.validate() configuration.report() - return configuration.expanded_configuration() + if DAG_dir: + nanoAOD_inputs = [n for n in dir(nanoAOD) if not n.startswith("__")] + create_graph(configuration, nanoAOD_inputs, DAG_dir, "CROWNelements") + configuration = configuration.expanded_configuration() + return configuration diff --git a/analysis_configurations/template_analysis/template_multifriend_config.py b/analysis_configurations/template_analysis/template_multifriend_config.py index 23eb54e0..7329ae41 100644 --- a/analysis_configurations/template_analysis/template_multifriend_config.py +++ b/analysis_configurations/template_analysis/template_multifriend_config.py @@ -4,9 +4,11 @@ from .producers import pairselection as pairselection from .producers import genparticles as genparticles +from .quantities import nanoAOD as nanoAOD from .quantities import output as q from code_generation.friend_trees import FriendTreeConfiguration from code_generation.rules import RemoveProducer +from code_generation.utility.generate_DAG import create_graph def build_config( @@ -18,6 +20,7 @@ def build_config( available_eras: List[str], available_scopes: List[str], quantities_map: Union[str, None] = None, + DAG_dir: str = "", ): configuration = FriendTreeConfiguration( era, @@ -62,4 +65,8 @@ def build_config( configuration.optimize() configuration.validate() configuration.report() - return configuration.expanded_configuration() + if DAG_dir: + nanoAOD_inputs = [n for n in dir(nanoAOD) if not n.startswith("__")] + create_graph(configuration, nanoAOD_inputs, DAG_dir, "CROWNelements") + configuration = configuration.expanded_configuration() + return configuration diff --git a/code_generation/analysis_template.cxx b/code_generation/analysis_template.cxx index 990f2939..31b05dd6 100644 --- a/code_generation/analysis_template.cxx +++ b/code_generation/analysis_template.cxx @@ -156,73 +156,80 @@ int main(int argc, char *argv[]) { const std::string analysis = {ANALYSISTAG}; const std::string config = {CONFIGTAG}; const std::string era = {ERATAG}; - const std::string sample = {SAMPLETAG}; + const std::string sample_type = {SAMPLETAG}; const std::string commit_hash = {COMMITHASH}; bool setup_clean = {CROWN_IS_CLEAN}; const std::string analysis_commit_hash = {ANALYSIS_COMMITHASH}; bool analysis_setup_clean = {ANALYSIS_IS_CLEAN}; int scope_counter = 0; for (auto const &output : output_quantities) { - // output.first is the output file name - // output.second is the list of quantities + TFile outputfile(output.first.c_str(), "UPDATE"); - TTree quantities_meta = TTree("quantities", "quantities"); - for (auto const &quantity : output.second) { - quantities_meta.Branch(quantity.c_str(), &setup_clean); - } - quantities_meta.Write(); - TTree variations_meta = TTree("variations", "variations"); - for (auto const &variation : variations.at(output.first)) { - variations_meta.Branch(variation.c_str(), &setup_clean); - } - variations_meta.Write(); - TTree conditions_meta = TTree("conditions", "conditions"); - conditions_meta.Branch(analysis.c_str(), &setup_clean); - conditions_meta.Branch(config.c_str(), &setup_clean); - conditions_meta.Branch(era.c_str(), &setup_clean); - conditions_meta.Branch(sample.c_str(), &setup_clean); - conditions_meta.Write(); - TTree commit_meta = TTree("commit", "commit"); - commit_meta.Branch(commit_hash.c_str(), &setup_clean); - commit_meta.Fill(); - commit_meta.Write(); - TTree analysis_commit_meta = - TTree("analysis_commit", "analysis_commit"); - analysis_commit_meta.Branch(analysis_commit_hash.c_str(), - &analysis_setup_clean); - analysis_commit_meta.Fill(); - analysis_commit_meta.Write(); + + // ----------------------------- + // Unified metadata object + // ----------------------------- + nlohmann::json j; + + j["metadata"] = { + {"analysis", analysis}, + {"config", config}, + {"era", era}, + {"sample_type", sample_type}, + {"commit", commit_hash}, + {"analysis_commit", analysis_commit_hash}, + {"is_clean", setup_clean}, + {"analysis_is_clean", analysis_setup_clean} + }; + + j["content"] = { + {"quantities", output.second}, + {"variations", variations.at(output.first)} + }; + + TObjString json(j.dump().c_str()); + outputfile.WriteObject(&json, "metadata"); + + // ----------------------------- + // Cutflow histogram + // ----------------------------- if (nevents != 0) { + TH1D cutflow; cutflow.SetName("cutflow"); cutflow.SetTitle("cutflow"); - // iterate through the cutflow vector and fill the histogram with - // the .GetPass() values - if (scope_counter >= cutReports.size()) { + + if (cutReports.size() < scope_counter || cutReports.empty()) { Logger::get("main")->critical( - "Cutflow vector is too small, this should not happen"); + "cutReports vector is too small, this should not happen"); return 1; } + for (auto cut = cutReports[scope_counter].begin(); - cut != cutReports[scope_counter].end(); cut++) { - cutflow.SetBinContent( - std::distance(cutReports[scope_counter].begin(), cut) + 1, - cut->GetPass()); - cutflow.GetXaxis()->SetBinLabel( - std::distance(cutReports[scope_counter].begin(), cut) + 1, - cut->GetName().c_str()); + cut != cutReports[scope_counter].end(); cut++) { + + int bin = std::distance(cutReports[scope_counter].begin(), cut) + 1; + + cutflow.SetBinContent(bin, cut->GetPass()); + cutflow.GetXaxis()->SetBinLabel(bin, cut->GetName().c_str()); } - // store it in the output file + cutflow.Write(); } - outputfile.Close(); - TFile *fout = TFile::Open(output.first.c_str(), "UPDATE"); + + // ----------------------------- + // Additional shift maps + // ----------------------------- Logger::get("main")->info("Writing quantities map to {}", output.first); - fout->WriteObject(&shift_quantities_map.at(output.first), - "shift_quantities_map"); - fout->WriteObject(&quantities_shift_map.at(output.first), - "quantities_shift_map"); - fout->Close(); + + outputfile.WriteObject(&shift_quantities_map.at(output.first), + "shift_quantities_map"); + + outputfile.WriteObject(&quantities_shift_map.at(output.first), + "quantities_shift_map"); + + outputfile.Close(); + scope_counter++; } diff --git a/code_generation/analysis_template_friends.cxx b/code_generation/analysis_template_friends.cxx index 1b1615d8..ace2d942 100644 --- a/code_generation/analysis_template_friends.cxx +++ b/code_generation/analysis_template_friends.cxx @@ -211,73 +211,80 @@ int main(int argc, char *argv[]) { const std::string analysis = {ANALYSISTAG}; const std::string config = {CONFIGTAG}; const std::string era = {ERATAG}; - const std::string sample = {SAMPLETAG}; + const std::string sample_type = {SAMPLETAG}; const std::string commit_hash = {COMMITHASH}; bool setup_clean = {CROWN_IS_CLEAN}; const std::string analysis_commit_hash = {ANALYSIS_COMMITHASH}; bool analysis_setup_clean = {ANALYSIS_IS_CLEAN}; int scope_counter = 0; for (auto const &output : output_quantities) { - // output.first is the output file name - // output.second is the list of quantities + TFile outputfile(output.first.c_str(), "UPDATE"); - TTree quantities_meta = TTree("quantities", "quantities"); - for (auto const &quantity : output.second) { - quantities_meta.Branch(quantity.c_str(), &setup_clean); - } - quantities_meta.Write(); - TTree variations_meta = TTree("variations", "variations"); - for (auto const &variation : variations.at(output.first)) { - variations_meta.Branch(variation.c_str(), &setup_clean); - } - variations_meta.Write(); - TTree conditions_meta = TTree("conditions", "conditions"); - conditions_meta.Branch(analysis.c_str(), &setup_clean); - conditions_meta.Branch(config.c_str(), &setup_clean); - conditions_meta.Branch(era.c_str(), &setup_clean); - conditions_meta.Branch(sample.c_str(), &setup_clean); - conditions_meta.Write(); - TTree commit_meta = TTree("commit", "commit"); - commit_meta.Branch(commit_hash.c_str(), &setup_clean); - commit_meta.Fill(); - commit_meta.Write(); - TTree analysis_commit_meta = - TTree("analysis_commit", "analysis_commit"); - analysis_commit_meta.Branch(analysis_commit_hash.c_str(), - &analysis_setup_clean); - analysis_commit_meta.Fill(); - analysis_commit_meta.Write(); + + // ----------------------------- + // Unified metadata container + // ----------------------------- + nlohmann::json j; + + j["metadata"] = { + {"analysis", analysis}, + {"config", config}, + {"era", era}, + {"sample_type", sample_type}, + {"commit", commit_hash}, + {"analysis_commit", analysis_commit_hash}, + {"is_clean", setup_clean}, + {"analysis_is_clean", analysis_setup_clean} + }; + + j["content"] = { + {"quantities", output.second}, + {"variations", variations.at(output.first)} + }; + + TObjString json(j.dump().c_str()); + outputfile.WriteObject(&json, "metadata"); + + // ----------------------------- + // Cutflow + // ----------------------------- if (nevents != 0) { + TH1D cutflow; cutflow.SetName("cutflow"); cutflow.SetTitle("cutflow"); - // iterate through the cutflow vector and fill the histogram with - // the .GetPass() values + if (cutReports.size() < scope_counter || cutReports.empty()) { Logger::get("main")->critical( "cutReports vector is too small, this should not happen"); return 1; } + for (auto cut = cutReports[scope_counter].begin(); - cut != cutReports[scope_counter].end(); cut++) { - cutflow.SetBinContent( - std::distance(cutReports[scope_counter].begin(), cut) + 1, - cut->GetPass()); - cutflow.GetXaxis()->SetBinLabel( - std::distance(cutReports[scope_counter].begin(), cut) + 1, - cut->GetName().c_str()); + cut != cutReports[scope_counter].end(); cut++) { + + int bin = std::distance(cutReports[scope_counter].begin(), cut) + 1; + + cutflow.SetBinContent(bin, cut->GetPass()); + cutflow.GetXaxis()->SetBinLabel(bin, cut->GetName().c_str()); } - // store it in the output file + cutflow.Write(); } - outputfile.Close(); - TFile *fout = TFile::Open(output.first.c_str(), "UPDATE"); + + // ----------------------------- + // Shift maps + // ----------------------------- Logger::get("main")->info("Writing quantities map to {}", output.first); - fout->WriteObject(&shift_quantities_map.at(output.first), - "shift_quantities_map"); - fout->WriteObject(&quantities_shift_map.at(output.first), - "quantities_shift_map"); - fout->Close(); + + outputfile.WriteObject(&shift_quantities_map.at(output.first), + "shift_quantities_map"); + + outputfile.WriteObject(&quantities_shift_map.at(output.first), + "quantities_shift_map"); + + outputfile.Close(); + scope_counter++; } diff --git a/code_generation/code_generation.py b/code_generation/code_generation.py index 6593831e..9fd9bfa8 100644 --- a/code_generation/code_generation.py +++ b/code_generation/code_generation.py @@ -390,12 +390,10 @@ def write_code(self, calls: str, includes: str, run_commands: str) -> None: .replace(" // {RUN_COMMANDS}", run_commands) .replace("// {MULTITHREADING}", threadcall) .replace("// {DEBUGLEVEL}", self.set_debug_flag()) - .replace("{ERATAG}", '"Era={}"'.format(self.configuration.era)) - .replace( - "{SAMPLETAG}", '"Samplegroup={}"'.format(self.configuration.sample) - ) - .replace("{ANALYSISTAG}", '"Analysis={}"'.format(self.analysis_name)) - .replace("{CONFIGTAG}", '"Config={}"'.format(self.config_name)) + .replace("{ERATAG}", '"{}"'.format(self.configuration.era)) + .replace("{SAMPLETAG}", '"{}"'.format(self.configuration.sample)) + .replace("{ANALYSISTAG}", '"{}"'.format(self.analysis_name)) + .replace("{CONFIGTAG}", '"{}"'.format(self.config_name)) .replace("{OUTPUT_QUANTITIES}", self.set_output_quantities()) .replace("{SHIFT_QUANTITIES_MAP}", self.set_shift_quantities_map()) .replace("{QUANTITIES_SHIFT_MAP}", self.set_quantities_shift_map()) diff --git a/code_generation/friend_trees.py b/code_generation/friend_trees.py index 84636249..69dcb9af 100644 --- a/code_generation/friend_trees.py +++ b/code_generation/friend_trees.py @@ -6,7 +6,7 @@ import os from time import time from code_generation.configuration import Configuration -from typing import List, Union, Dict, Set +from typing import List, Union, Dict, Set, Tuple, Any from code_generation.exceptions import ( ConfigurationError, @@ -141,21 +141,22 @@ def _determine_requested_shifts(self, shiftset: Set[str]) -> Dict[str, List[str] def _readout_input_information( self, input_information_list: Union[List[str], List[Dict[str, List[str]]]], + metadata: Dict[str, Any] = {}, ) -> Dict[str, Dict[str, List[str]]]: def update_input_information(existing_data, new_data): - if existing_data == {}: - return new_data - else: - # otherwise we have to merge the contents, while not overwriting existing data - for scope in new_data.keys(): - if scope not in existing_data.keys(): - existing_data[scope] = {} - for shift in new_data[scope].keys(): - if shift not in existing_data[scope].keys(): - existing_data[scope][shift] = [] - for quantity in new_data[scope][shift]: - if quantity not in existing_data[scope][shift]: - existing_data[scope][shift].append(quantity) + # Merge contents, while not overwriting existing data + for scope in new_data.keys(): + if scope not in existing_data.keys(): + existing_data[scope] = {} + for shift in new_data[scope].keys(): + if shift not in existing_data[scope].keys(): + existing_data[scope][shift] = [] + for quantity in new_data[scope][shift]: + if quantity not in existing_data[scope][shift]: + # Add origin config to quantity for naviagation with multifriends + existing_data[scope][shift].append( + (quantity, metadata["config"]) + ) return existing_data # first check if the input is a root file or a json file @@ -164,13 +165,15 @@ def update_input_information(existing_data, new_data): log.info(f"adding input information from {input_information}") if isinstance(input_information, str): if input_information.endswith(".root"): - data = update_input_information( - data, self._readout_input_root_file(input_information) + shift_map, metadata = self._readout_input_root_file( + input_information ) + data = update_input_information(data, shift_map) elif input_information.endswith(".json"): - data = update_input_information( - data, self._readout_input_json_file(input_information) + shift_map, metadata = self._readout_input_json_file( + input_information ) + data = update_input_information(data, shift_map) else: error_message = f"\n Input information file {input_information} is not a json or root file \n" error_message += ( @@ -183,7 +186,7 @@ def update_input_information(existing_data, new_data): def _readout_input_root_file( self, input_file: str - ) -> Dict[str, Dict[str, List[str]]]: + ) -> Tuple[Dict[str, Dict[str, List[str]]], Dict[str, Any]]: """Read the shift_quantities_map from the input root file and return it as a dictionary Args: @@ -207,7 +210,7 @@ def _readout_input_root_file( if not os.path.exists(lib_path): log.error(f"Missing library: {lib_path}") # Evaluate ROOT-specific return codes - result = ROOT.gSystem.Load(lib_path) + result = ROOT.gSystem.Load(lib_path) # type: ignore if result < 0: err_type = ( "Version mismatch" @@ -217,19 +220,19 @@ def _readout_input_root_file( log.error(f"Load failed ({result}): {err_type} for {lib_path}") f = ROOT.TFile.Open(input_file) # type: ignore - name = "shift_quantities_map" - m = f.Get(name) + m = f.Get("shift_quantities_map") for shift, quantities in m: data[str(shift)] = [str(quantity) for quantity in quantities] + metadata = json.loads(f.Get("metadata").GetString().Data()) f.Close() log.debug( f"Reading quantities information took {round(time() - start,2)} seconds" ) - return {list(self.selected_scopes)[0]: data} + return {list(self.selected_scopes)[0]: data}, metadata["metadata"] def _readout_input_json_file( self, input_file: str - ) -> Dict[str, Dict[str, List[str]]]: + ) -> Tuple[Dict[str, Dict[str, List[str]]], Dict[str, Any]]: """Read the shift_quantities_map from the input json file and return it as a dictionary Args: @@ -240,26 +243,33 @@ def _readout_input_json_file( """ with open(input_file) as f: data = json.load(f) + quantity_data = data["quantities"] + metadata = data["metadata"] # json file structure is: {era: {sampletype: {scope: {shift: [quantities]}}} - if self.era not in data: + if self.era not in quantity_data or self.era != metadata["era"]: errorstring = ( f"Era {self.era} not found in input information file {input_file}.\n" ) - errorstring += f"Available eras are: {data.keys()}" + errorstring += f"Available eras are: {quantity_data.keys()}" raise ConfigurationError(errorstring) - if self.sample not in data[self.era].keys(): + if ( + self.sample not in quantity_data[self.era].keys() + or self.sample != metadata["sample_type"] + ): errorstring = f"Sampletype {self.sample} not found in input information file {input_file}.\n" - errorstring += f"Available sampletypes are: {data[self.era].keys()}" + errorstring += ( + f"Available sampletypes are: {quantity_data[self.era].keys()}" + ) raise ConfigurationError(errorstring) if not set(self.selected_scopes).issubset( - set(data[self.era][self.sample].keys()) + set(quantity_data[self.era][self.sample].keys()) ): errorstring = f"Scopes {self.selected_scopes} not found in input information file {input_file}.\n" - errorstring += f"Available scopes are: {data[self.era][self.sample].keys()}" + errorstring += ( + f"Available scopes are: {quantity_data[self.era][self.sample].keys()}" + ) raise ConfigurationError(errorstring) - else: - data = data[self.era][self.sample] - return data + return quantity_data[self.era][self.sample], metadata def optimize(self) -> None: """ @@ -375,7 +385,9 @@ def _validate_inputs(self) -> None: [x.name for x in producer.get_outputs(scope)] ) # get all available inputs - for input_quantitiy in self.input_quantities_mapping[scope][""]: + for input_quantitiy, quantity_origin in self.input_quantities_mapping[ + scope + ][""]: available_inputs.add(input_quantitiy) # now check if all inputs are available missing_inputs = required_inputs - available_inputs diff --git a/code_generation/helpers.py b/code_generation/helpers.py index 75a5ce04..7e17987f 100644 --- a/code_generation/helpers.py +++ b/code_generation/helpers.py @@ -1,9 +1,10 @@ from __future__ import annotations # needed for type annotations in > python 3.7 +from typing import Any # File with helper functions for the CROWN code generation -def is_empty(value): +def is_empty(value: Any) -> bool: """ Check if a value is empty. @@ -13,12 +14,11 @@ def is_empty(value): Returns: bool: Whether the input value is considered 'empty' """ - # List of all values that should be considered empty despite not having a length. empty_values = [None] try: length = len(value) except TypeError: length = -1 - bool_val = value in empty_values or length == 0 - return bool_val + + return value in empty_values or length == 0 diff --git a/code_generation/producer.py b/code_generation/producer.py index 69fa45f8..0332c576 100644 --- a/code_generation/producer.py +++ b/code_generation/producer.py @@ -28,6 +28,7 @@ def __init__( input: Union[List[q.Quantity], Dict[str, List[q.Quantity]]], output: Union[List[q.Quantity], None], scopes: List[str], + is_filter: bool = False, ): """ A Producer is a class that holds all information about a producer. Input quantities are @@ -38,6 +39,7 @@ def __init__( input: A list of input quantities or a dict with scope specific input quantities output: A list of output quantities scopes: A list of scopes in which the producer is used + is_filter: True if the producer is a filter """ log.debug("Setting up a new producer {}".format(name)) @@ -57,6 +59,7 @@ def __init__( self.call: str = call self.output: Union[List[q.Quantity], None] = output self.scopes = scopes + self.is_filter = is_filter self.parameters: Dict[str, Set[str]] = self.extract_parameters() # if input not given as dict and therfore not scope specific transform into dict with all scopes if not isinstance(input, dict): @@ -73,7 +76,10 @@ def __init__( for output_quantity in self.output: input_quantity.adopt(output_quantity, scope) log.debug("-----------------------------------------") - log.debug("| Producer: {}".format(self.name)) + if self.is_filter: + log.debug("| Filter Producer: {}".format(self.name)) + else: + log.debug("| Producer: {}".format(self.name)) log.debug("| Call: {}".format(self.call)) for scope in self.scopes: if is_empty(self.input[scope]): @@ -374,6 +380,7 @@ def __init__( output: Union[List[q.Quantity], None], scopes: List[str], vec_configs: List[str], + is_filter: bool = False, ): """A Vector Producer is a Producer which can be configured to produce multiple calls and outputs at once, deprecated in favor of the ExtendedVectorProducer @@ -384,9 +391,10 @@ def __init__( output (Union[List[q.Quantity], None]): The outputs of the producer, either a list of Quantity objects, or None if the producer does not produce any output scopes (List[str]): The scopes in which the producer is to be called vec_configs (List[str]): A list of strings, which are the names of the parameters to be varied in the vectorized call + is_filter (bool, optional): Whether the vector producer is a filter. Defaults to False. """ self.name = name - super().__init__(name, call, input, output, scopes) + super().__init__(name, call, input, output, scopes, is_filter) self.vec_configs = vec_configs def __str__(self) -> str: @@ -463,6 +471,7 @@ def __init__( output: str, scope: Union[List[str], str], vec_config: str, + is_filter: bool = False, ): """A ExtendedVectorProducer is a Producer which can be configured to produce multiple calls and outputs at once @@ -473,6 +482,7 @@ def __init__( output (Union[List[q.Quantity], None]): The outputs of the producer, either a list of Quantity objects, or None if the producer does not produce any output scopes (List[str]): The scopes in which the producer is to be called vec_configs (List[str]): The key of the vec config in the config dict + is_filter (bool, optional): Whether the vectorproducer is a filter. Defaults to False. """ # we create a Quantity Group, which is updated during the writecalls() step self.outputname = output @@ -482,7 +492,7 @@ def __init__( quantity_group = q.QuantityGroup(name) # set the vec config key of the quantity group quantity_group.set_vec_config(vec_config) - super().__init__(name, call, input, [quantity_group], scope) + super().__init__(name, call, input, [quantity_group], scope, is_filter) if is_empty(self.output): raise InvalidProducerConfigurationError(self.name) # add the vec config to the parameters of the producer @@ -566,7 +576,7 @@ def __init__( either a list of Quantity objects, or a dict with the scope as key and a list of Quantity objects as value scopes (List[str]): The scopes in which the filter is to be called """ - super().__init__(name, call, input, None, scopes) + super().__init__(name, call, input, None, scopes, True) def __str__(self) -> str: return "BaseFilter: {}".format(self.name) diff --git a/code_generation/utility/CROWN_visualization.html b/code_generation/utility/CROWN_visualization.html new file mode 100644 index 00000000..631d67f2 --- /dev/null +++ b/code_generation/utility/CROWN_visualization.html @@ -0,0 +1,995 @@ + + + + + CROWN Graph Explorer + + + + + + + + +
+
+

Controls

+ +
+ +
+
+
+ +
+
+ + + + +
+ +
+ +
+ +
+ Input Quantities +
+ +
+
+
+ +
+ Output Quantities +
+ +
+
+
+ + + +
+
Graph Legend
+
+
+
+ Reads from NanoAOD +
+
+
+ Reads from Ntuple +
+
+
+ Writes to file +
+
+
+ Affected by Shift +
+
+
+ Group +
+
+
+ Vector Group +
+
+
+ Filter +
+
+
+ Unused Producer +
+
+
+
+
+
+ + + +
+
+ + + + \ No newline at end of file diff --git a/code_generation/utility/generate_DAG.py b/code_generation/utility/generate_DAG.py new file mode 100644 index 00000000..830d9f9f --- /dev/null +++ b/code_generation/utility/generate_DAG.py @@ -0,0 +1,878 @@ +import re +import json +import shutil +import logging +from pathlib import Path +from collections import defaultdict +from code_generation.quantity import QuantityGroup +from code_generation.helpers import is_empty +from typing import Never + +log = logging.getLogger(__name__) + + +def log_and_fail(msg: str) -> Never: + log.exception(msg, stack_info=True) + raise RuntimeError(msg) + + +def create_graph(configuration, NanoAOD_inputs, DAG_dir, json_name): + """Instantiate a GraphParser and execute DAG generation. + + Args: + configuration (object): The configuration object containing scopes and producers. + NanoAOD_inputs (list): List of NanoAOD input quantities. + DAG_dir (str): Directory where the generated DAG JSON files will be saved. + json_name (str): The base name for the output JSON file. + """ + for active_scope in configuration.scopes: + # Skip global scope as it is included in all other scopes + if active_scope == "global": + pass + else: + # Add Ntuple quantities for friends + + if type(configuration).__name__ == "FriendTreeConfiguration": + external_inputs = {} + is_friend_config = True + elif type(configuration).__name__ == "Configuration": + external_inputs = {"NanoAOD": NanoAOD_inputs} + is_friend_config = False + else: + log_and_fail( + f"Unknown configuration type: {type(configuration).__name__}" + ) + + if hasattr(configuration, "input_quantities_mapping"): + if not is_friend_config: + log_and_fail( + "input_quantities_mapping is only supported for friend configurations." + ) + external_inputs_tuples = configuration.input_quantities_mapping[ + active_scope + ][""] + config_sources = defaultdict(list) + for quantity, config in external_inputs_tuples: + config_sources[config].append(quantity) + external_inputs.update(config_sources) + shifted_inputs = { + shift: [original for original, *_ in e_in] + for shift, e_in in configuration.input_quantities_mapping[ + active_scope + ].items() + if shift != "" + } + else: + shifted_inputs = None + graph = GraphParser( + configuration, + external_inputs, + active_scope, + DAG_dir, + is_friend_config=is_friend_config, + shifted_inputs=shifted_inputs, + ) + # Generate and save graph for each scope separately + graph.generate_graph() + graph.save_graph(json_name) + + +class GraphParser: + """Parses configuration data to generate Directed Acyclic Graphs (DAGs). + + Represents the dependencies and flow of producers, filters, and I/O. + """ + + def __init__( + self, + config, + external_inputs, + active_scope, + DAG_dir="", + is_friend_config=False, + shifted_inputs=None, + ): + """Initialize the GraphParser object. + + Args: + config (object): The configuration object. + external_inputs (list): List of external inputs required by the graph. + active_scope (str): The specific scope being parsed in addition to the global scope. + DAG_dir (str, optional): The directory for saving DAG files. Defaults to "". + is_friend_config (bool, optional): Indicates if it's a friend configuration. Defaults to False. + """ + self.config = config + self.active_scope = active_scope + if active_scope == "global": + self.scopes = ["global"] + else: + self.scopes = ["global", self.active_scope] + self.external_inputs = external_inputs + self.is_friend_config = is_friend_config + self.shifted_inputs = shifted_inputs + self.connections = defaultdict(lambda: defaultdict(list)) + self.edges = [] + self.inputs = defaultdict(lambda: defaultdict(list)) + self.outputs = defaultdict(lambda: defaultdict(list)) + self.direct_in_to_out = defaultdict(list) + self.vec_output_mappings = {} + self.node_register = defaultdict( + lambda: { + "name": None, + "parent": None, + "type": None, + "file_in": defaultdict(list), + "file_out": [], + "node_call": None, + "node_call_configs": {}, + } + ) + self.shift_registry = defaultdict() + self.DAG_dir = DAG_dir + + def generate_graph(self): + """Execute the main orchestration to generate nodes and edges. + + Iterates through all scopes of the configuration, builds the DAG components, + verifies output ambiguity, bundles edges, and generates formatting metadata + for visualization tools like Cytoscape. + """ + # Determine producers from configuration for global and active scope + for scope in self.scopes: + self.parse_scope_producers(scope) + + # Derive connections from inputs/outputs + self.assemble_connections() + + # Identify nodes without a specific type and without any outgoing connections + # This excludes filters, groups, vector groups, and scopes. + # While they are only defined afterwards this also excludes proxy nodes and edges (branch, twig, leaf). + for node_id, node_data in self.node_register.items(): + if ( + not node_data.get("type") + and is_empty(node_data.get("file_out")) + and node_id not in self.connections + ): + node_data["type"] = "stump" + + self.compile_shift_registry() + + def parse_scope_producers(self, scope): + """Parse producers and outputs for a specific scope. + + Args: + scope (str): The scope name to process. + + Raises: + ValueError: If an output has multiple origins (ambiguous assignments). + """ + if not is_empty(self.config.producers[scope]): + log.debug(f"For scope {scope}:") + # Add top level scope node + self.add_node( + id_name="scope", + name=f"{scope} scope", + scope=scope, + node_type="scope", + ) + + # Parse all producers in this scope + for p in self.config.producers[scope]: + self.parse_Producer_routing( + producer=p, parent="scope", scope=scope, align=" " + ) + + # Check outputs for ambiguous assignments (multiple origins for the same output) + all_keys = list( + set(list(self.outputs["global"].keys()) + list(self.outputs[scope].keys())) + ) + for key in all_keys: + total_output_nodes = self.outputs["global"].get(key, []) + self.outputs[ + scope + ].get(key, []) + if len(set(total_output_nodes)) != 1: + log_and_fail(f"Output {key} has multiple origins: {total_output_nodes}") + + # Determine nodes writing to Ntuple by tracing configured scope outputs back to their producers + if hasattr(self.config, "outputs") and scope in self.config.outputs: + for output in self.config.outputs[scope]: + if isinstance(output, QuantityGroup): + if output.vec_config in self.config.config_parameters.get( + scope, {} + ): + vec_config = self.config.config_parameters[scope][ + output.vec_config + ] + vec_output_name = self.vec_output_mappings.get(output.name) + if vec_output_name: + for o in vec_config: + req_out = o.get(vec_output_name) + if req_out: + self.set_is_out(scope, req_out) + else: + self.set_is_out(scope, output.name) + + def parse_Producer_routing(self, producer, parent, scope, align=""): + """Route parsing logic based on the specific producer class type. + + Args: + producer (object): The producer instance to be parsed. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Spacing string used for debug print alignment. Defaults to "". + + Raises: + NotImplementedError: If the producer's class is unknown. + """ + class_name = producer.__class__.__name__ + + if class_name == "VectorProducer": + self.parse_VectorProducer( + producer=producer, parent=parent, scope=scope, align=align + ) + elif class_name == "ExtendedVectorProducer": + self.parse_ExtendedVectorProducer( + producer=producer, parent=parent, scope=scope, align=align + ) + elif class_name == "ProducerGroup": + self.parse_ProducerGroup( + producer=producer, parent=parent, scope=scope, align=align + ) + elif class_name == "Filter": + self.parse_Filter( + producer=producer, parent=parent, scope=scope, align=align + ) + elif class_name == "BaseFilter": + self.parse_BaseFilter( + producer=producer, parent=parent, scope=scope, align=align + ) + elif class_name == "Producer": + self.parse_Producer( + producer=producer, parent=parent, scope=scope, align=align + ) + else: + log_and_fail(f"Unknown Producer class {class_name}") + + def parse_VectorProducer(self, producer, parent, scope, align=""): + """Map legacy vector configurations to graph nodes and inputs using regex. + + Note: + Currently only supports 'event::filter::Flag'. This is a legacy function + and producers should be migrated to ExtendedVectorProducer. + + Args: + producer (object): The legacy vector producer instance. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Debug print alignment spacing. Defaults to "". + + Raises: + NotImplementedError: If the producer call does not have legacy support. + """ + log.debug(f"{align}Adding VectorProducer: {producer.name}") + log.warning( + f"!!! {producer.name} is a legacy producer and should be replaced with ExtendedVectorProducer !!!" + ) + + # Only accept whitelisted calls for legacy support + if producer.call.startswith("event::filter::Flag"): + self.parse_Flag_from_call( + producer=producer, parent=parent, scope=scope, align=align + ) + else: + log_and_fail(f"The call {producer.call} does not have legacy support.") + + def parse_Flag_from_call(self, producer, parent, scope, align=""): + """Extract vector configurations from a legacy producer's call string using regex. + + Args: + producer (object): The legacy flag producer instance. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Debug print alignment spacing. Defaults to "". + + Raises: + ValueError: If input vector configs cannot be parsed or matched. + NotImplementedError: If list outputs are used (unsupported in legacy). + """ + group_id = f"{producer.name}_v" + call = producer.call + # Pattern consists of: event::filter::Flag(df, "Flag_name", "Flag_name") + pattern = r'event::filter::Flag\({df}, "{(.*)}", "{(.*)}"\)' + match = re.search(pattern, call) + + if match: + input_vec_config = match.group(1) + else: + log_and_fail(f"Input vector config could not be parsed from {call}") + + if input_vec_config not in producer.vec_configs: + log_and_fail( + f"Input name from {call} not in producer vector configs {producer.vec_configs}" + ) + # Determine the index of the input vector config from the call + vec_input_index = producer.vec_configs.index(input_vec_config) + call = producer.call + self.add_node( + id_name=group_id, + name=producer.name, + scope=scope, + parent=parent, + node_type="vector", + ) + vec_configs = [ + self.config.config_parameters[scope][c] for c in producer.vec_configs + ] + # Loop over all vector configurations + for i_c, c in enumerate(zip(*vec_configs)): + input_name = c[vec_input_index] + vector_id = f"{group_id}_{i_c}" + vector_name = f"{producer.name}_{i_c}" + vector_config_dict = {vc: cc for vc, cc in zip(producer.vec_configs, c)} + config_data = self.extract_configs(call, scope, vector_config_dict) + log.debug(f"{align} Adding Filter: {vector_name}") + self.add_node( + id_name=vector_id, + name=vector_name, + scope=scope, + parent=group_id, + is_filter=producer.is_filter, + node_call=call, + node_call_configs=config_data, + ) + self.add_input(input_name, vector_id, scope) + + # outputs are not supported for this legacy producer + if isinstance(producer.output, list): + log_and_fail("List outputs for legacy parsed calls are not supported.") + + def parse_ExtendedVectorProducer(self, producer, parent, scope, align=""): + """Parse an ExtendedVectorProducer class instance into graph nodes, inputs, and outputs. + + Args: + producer (object): The ExtendedVectorProducer instance. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Debug print alignment spacing. Defaults to "". + """ + log.debug(f"{align}Adding ExtendedVectorProducer: {producer.name}") + group_id = f"{producer.name}_v" + self.add_node( + id_name=group_id, + name=producer.name, + scope=scope, + parent=parent, + node_type="vector", + ) + + # ExtendedVectorProducer is better defined and doesn't require regex + call = producer.call + vec_config = self.config.config_parameters[scope][producer.vec_config] + # Loop over all vector configurations + for i_c, c in enumerate(vec_config): + vector_id = f"{group_id}_{i_c}" + vector_name = f"{producer.name}_{i_c}" + config_data = self.extract_configs(call, scope, c) + log.debug(f"{align} Adding Producer: {vector_name}") + self.add_node( + id_name=vector_id, + name=vector_name, + scope=scope, + parent=group_id, + is_filter=producer.is_filter, + node_call=call, + node_call_configs=config_data, + ) + + if isinstance(producer.input[scope], list): + for n in set(producer.input[scope]): + self.add_input(n.name, vector_id, scope) + log.debug(f"{align} Adding Input: {n.name}") + + if isinstance(producer.output, list): + vec_output = c[producer.outputname] + self.vec_output_mappings[producer.output_group.name] = ( + producer.outputname + ) + self.add_output(vec_output, vector_id, scope) + log.debug(f"{align} Adding Output: {vec_output}") + + def parse_ProducerGroup(self, producer, parent, scope, align=""): + """Parse a ProducerGroup class into graph nodes and recursively route its producers. + + Args: + producer (object): The ProducerGroup instance. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Debug print alignment spacing. Defaults to "". + """ + log.debug(f"{align}Adding ProducerGroup: {producer.name}") + + group_id = f"{producer.name}_g" + self.add_node( + id_name=group_id, + name=producer.name, + scope=scope, + parent=parent, + node_type="group", + ) + # Add all producers in group recursively + for p in producer.producers[scope]: + self.parse_Producer_routing( + producer=p, parent=group_id, scope=scope, align=align + " " + ) + + def parse_Filter(self, producer, parent, scope, align=""): + """Parse a Filter class into graph nodes and recursively route its components. + + Args: + producer (object): The Filter instance. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Debug print alignment spacing. Defaults to "". + """ + log.debug(f"{align}Adding Filter Group: {producer.name}") + + group_id = f"{producer.name}_f" + self.add_node( + id_name=group_id, + name=producer.name, + scope=scope, + parent=parent, + node_type="group", + ) + + # Add all filters in group recursively + for p in producer.producers[scope]: + self.parse_Producer_routing( + producer=p, parent=group_id, scope=scope, align=align + " " + ) + + def parse_BaseFilter(self, producer, parent, scope, align=""): + """Parse a legacy BaseFilter class into graph nodes and inputs. + + Args: + producer (object): The legacy BaseFilter instance. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Debug print alignment spacing. Defaults to "". + """ + log.debug(f"{align}Adding BaseFilter: {producer.name}") + log.warning( + f"!!! {producer.name} is a legacy producer and should be replaced with Filter !!!" + ) + call = producer.call + config_data = self.extract_configs(call, scope) + self.add_node( + id_name=producer.name, + scope=scope, + parent=parent, + is_filter=producer.is_filter, + node_call=call, + node_call_configs=config_data, + ) + # BaseFilter don't have outputs + for n in set(producer.input[scope]): + self.add_input(n.name, producer.name, scope) + log.debug(f"{align} Adding Input: {n.name}") + + def parse_Producer(self, producer, parent, scope, align=""): + """Parse a generic Producer class into graph nodes, inputs, and outputs. + + Args: + producer (object): The Producer instance. + parent (str): The ID of the parent node. + scope (str): The active scope. + align (str, optional): Debug print alignment spacing. Defaults to "". + """ + log.debug(f"{align}Adding Producer: {producer.name}") + call = producer.call + config_data = self.extract_configs(call, scope) + self.add_node( + id_name=producer.name, + scope=scope, + parent=parent, + is_filter=producer.is_filter, + node_call=call, + node_call_configs=config_data, + ) + + if isinstance(producer.input[scope], list): + for n in set(producer.input[scope]): + self.add_input(n.name, producer.name, scope) + log.debug(f"{align} Adding Input: {n.name}") + + if isinstance(producer.output, list): + for n in set(producer.output): + self.add_output(n.name, producer.name, scope) + log.debug(f"{align} Adding Output: {n.name}") + + def set_is_out(self, scope, req_out): + """Designate a node as the source of an Ntuple output. + + Args: + scope (str): The active scope. + req_out (str): The name of the requested output quantity. + + Raises: + ValueError: If the number of producers is invalid, the source is missing, + or the output is not provided by NanoAOD/Ntuple. + """ + producers = self.outputs[scope].get(req_out, []) + self.outputs["global"].get( + req_out, [] + ) + if len(producers) > 0: + if len(producers) != 1: + log_and_fail(f"Num producers for out {req_out}: {len(producers)}") + if self.node_register.get(producers[0]): + self.node_register[producers[0]]["file_out"].append(req_out) + else: + log_and_fail( + f"Source {producers[0]} is neither part of {scope} nor global scope." + ) + else: + if req_out in set().union(*self.external_inputs.values()): + if self.is_friend_config: + keys = [ + k + for k, values in self.external_inputs.items() + if req_out in values + ] + if len(keys) > 1: + log.warning( + f"Input {req_out} available from multiple sources {keys}. Picking the first one {keys[0]}." + ) + quantity_source = keys[0] + log.debug( + f"Requested output quantity {req_out} provided by {quantity_source} Ntuple." + ) + self.direct_in_to_out[quantity_source].append(req_out) + else: + log.debug( + f"Requested output quantity {req_out} provided by NanoAOD." + ) + self.direct_in_to_out["NanoAOD"].append(req_out) + else: + log_and_fail( + f"Requested output quantity {req_out} not provided by NanoAOD/Ntuple." + ) + + def assemble_connections(self): + """Resolve mappings to construct the actual connecting edges in the DAG. + + Determines the precise source of each input requirement, whether it + originates from another Producer, NanoAOD, or Ntuple. + + Raises: + ValueError: If an input is entirely missing from both external sources and internal producers. + """ + all_external_inputs = set().union(*self.external_inputs.values()) + for scope in self.scopes: + # Assemble connections by iterating through all nodes + # {target_node: [required_inputs]} is matched to {input: [source_nodes]} + # Connection is of shape {source: {required_inputs: [targets]}} + for target_node, required_inputs in self.inputs[scope].items(): + compose = defaultdict(list) + + # Determine where each input comes from + for req_input in required_inputs: + if self.outputs["global"].get(req_input): + source = self.outputs["global"][req_input][0] + compose[source].append(req_input) + elif self.outputs[scope].get(req_input): + source = self.outputs[scope][req_input][0] + compose[source].append(req_input) + elif req_input in all_external_inputs: + # CROWN friend production may only read quantities from CROWN Ntuples + if self.is_friend_config: + keys = [ + k + for k, values in self.external_inputs.items() + if req_input in values + ] + if len(keys) > 1: + log.warning( + f"Input {req_input} available from multiple sources {keys}. Picking the first one {keys[0]}." + ) + quantity_source = keys[0] + self.node_register[target_node]["file_in"][ + quantity_source + ].append(req_input) + else: + self.node_register[target_node]["file_in"][ + "NanoAOD" + ].append(req_input) + else: + log_and_fail( + f"Input {req_input} is missing from NanoAOD/Ntuple and producers." + ) + + # Create connections grouped by source node + for source_node, input_names in compose.items(): + for input_name in input_names: + log.debug( + f"Adding Connection {input_name} from {source_node} to {target_node}" + ) + self.add_connection( + source=source_node, target=target_node, name=input_name + ) + + def compile_shift_registry(self): + # Aggregate all shifts + # Fails if configurations would be overwritten by another scope + aggregated_shifts = defaultdict(defaultdict) + for scope in self.scopes: + for shift, value in self.config.shifts[scope].items(): + # Strip "__" of shift name + if not shift[0:2] == "__": + log_and_fail(f"Shift names must start with '__' -> Shift {shift}") + shift = shift[2:] + merger = aggregated_shifts[shift] + conflict = merger.keys() & value.keys() and merger != value + if conflict: + log_and_fail(f"Merge conflict on keys: {conflict}") + merger |= value + + for shift, shift_configs in aggregated_shifts.items(): + shift_heads = set() + for shift_cfg in shift_configs: + for node, node_data in self.node_register.items(): + if shift_cfg in node_data["node_call_configs"]: + shift_heads.add(node) + if self.shifted_inputs: + shifted_inputs = self.shifted_inputs[shift] + else: + shifted_inputs = [] + self.shift_registry[shift] = { + "heads": list(shift_heads), + "shift_configs": shift_configs, + "shifted_inputs": shifted_inputs, + } + + def get_downstream(self, source_node, downstream=None): + # Get all downstream nodes + if downstream is None: + downstream = {"edges": set(), "nodes": set()} + for edge, nodes in self.connections[source_node].items(): + downstream["edges"].add(f"edge_{edge}") + for node in nodes: + if node not in downstream["nodes"]: + downstream["nodes"].add(node) + ds_dat = self.get_downstream(node, downstream) + downstream["edges"].update(ds_dat["edges"]) + downstream["nodes"].update(ds_dat["nodes"]) + + return downstream + + def extract_configs(self, call, scope, vector_configs=None, ignore=None): + """Parse out explicit configuration string replacements from a call parameter. + + Args: + call (str): The raw call string containing formatting templates. + scope (str): The active scope context. + vector_configs (dict, optional): Specific vector configurations to prefer. Defaults to None. + ignore (set/list, optional): Strings/keys to ignore when extracting. Defaults to None. + + Returns: + dict: The mapped configurations dict resolving template markers to true values. + + Raises: + ValueError: If an unknown configuration parameter is encountered. + """ + # Ignore parameters that are not config parameters + if is_empty(ignore): + ignore = { + "df", + "output", + "input", + "output_vec", + "input_vec", + "vec_open", + "vec_close", + } + else: + ignore = set(ignore) # type: ignore + + pattern = r"\{(\w+)\}" + matches = re.findall(pattern, call) + + config_parameters = [m for m in matches if m not in ignore] + + # Get config parameter values from vector or general configs + config_dict = {} + for c in config_parameters: + if vector_configs != None and not is_empty(vector_configs.get(c)): + config_dict[c] = vector_configs[c] + elif not is_empty(self.config.config_parameters[scope].get(c)): + config_dict[c] = self.config.config_parameters[scope][c] + else: + log_and_fail(f"Unknown config parameter {c}") + return config_dict + + def add_output(self, output_name, output_node, scope): + """Register a node internally as the generator of a specific output. + + Args: + output_name (str): The name of the output being produced. + output_node (str): The ID of the node producing it. + scope (str): The active scope context. + + Raises: + ValueError: If the node is already registered for this specific output. + """ + node_id = f"{scope}_{output_node}" + if node_id in self.outputs[scope][output_name]: + log_and_fail(f"Node {node_id} already exists in output nodes.") + self.outputs[scope][output_name].append(node_id) + + def add_input(self, input_name, input_node, scope): + """Register a node internally as requiring a specific input dependency. + + Args: + input_name (str): The name of the required input. + input_node (str): The ID of the node needing it. + scope (str): The active scope context. + + Raises: + ValueError: If the input is already registered for this node. + """ + scoped_input = f"{scope}_{input_node}" + if input_name in self.inputs[scope][scoped_input]: + log_and_fail(f"Input {input_name} already exists in inputs.") + self.inputs[scope][scoped_input].append(input_name) + + def add_node( + self, + id_name, + name=None, + scope=None, + parent=None, + node_type=None, + is_filter=False, + family=None, + node_call=None, + node_call_configs=None, + ): + """Create a standardized node object and append it to the graph's internal list. + + Also manages appending metadata to the node and edge family registers. + + Args: + id_name (str): The base string ID of the node. + name (str, optional): Visual label for the node. Defaults to id_name. + scope (str, optional): The active scope context. Defaults to None. + parent (str, optional): The ID of the parent node. Defaults to None. + node_type (str, optional): Specific class or type of the node. Defaults to None. + is_filter (bool, optional): Indicates if the node acts as a filter. Defaults to False. + family (str, optional): Ties the node to a grouping/proxy family. Defaults to None. + node_call_data (dict, optional): Stores string call and extraction data. Defaults to None. + + Raises: + ValueError: If conflicting args (node_type and is_filter) are supplied, + or if a full ID already exists within the standard family register. + """ + if node_type and is_filter: + log_and_fail("Cannot specify both node_type and is_filter") + if is_filter: + node_type = "filter" + if not name: + name = id_name + if scope: + id_name = f"{scope}_{id_name}" + if parent: + if scope: + parent = f"{scope}_{parent}" + if self.node_register.get(id_name): + log_and_fail(f"Node {id_name} already exists in node register") + self.node_register[id_name]["name"] = name + if parent: + self.node_register[id_name]["parent"] = parent + if node_type: + self.node_register[id_name]["type"] = node_type + if node_call: + self.node_register[id_name]["node_call"] = node_call + self.node_register[id_name]["node_call_configs"] = node_call_configs + + def add_connection(self, source, target, name): + """Log a relational connection locally between a source and a target. + + Args: + source (str): ID of the source node providing data. + target (str): ID of the target node consuming data. + name (str): Connection quantity label. + """ + self.connections[source][name].append(target) + + def save_graph(self, name): + """Compile the DAG data structures, inject metadata, and export to a JSON file. + + Also triggers an automatic update to the overarching DAG file tracker. + + Args: + name (str): The base filename used to construct the final exported JSON path. + """ + path = f"{name}_{self.config.era}_{self.config.sample}_{self.active_scope}.json" + if self.DAG_dir: + path = Path(self.DAG_dir) / path + Path(self.DAG_dir).mkdir(parents=True, exist_ok=True) + + # Compile DAG data with metadata + full_data = { + "nodeRegister": self.node_register, + "edgeRegister": self.connections, + "directInputOutput": self.direct_in_to_out, + "shiftRegistry": self.shift_registry, + } + + # Copy visualization file to build dir + script_current_dir = Path(__file__).parent.resolve() + source = script_current_dir / "CROWN_visualization.html" + target = Path(self.DAG_dir) / "index.html" + if not target.exists(): + shutil.copy(source, target) + + # Write DAG data to json + with open(path, "w") as f: + json.dump(full_data, f, indent=4) + log.info( + f"Generated DAG file for {self.config.era}/{self.config.sample}/{self.active_scope}: {path}" + ) + + # Update master DAG file list + self.update_DAG_file_list( + Path(self.DAG_dir) / "DAG_files.json", + self.config.era, + self.config.sample, + self.active_scope, + ) + + def update_DAG_file_list(self, config_path, new_era, new_sample, new_scope): + """Maintain and append to the master JSON manifest tracking generated DAG elements. + + Prevents duplicates while verifying the registration of distinct eras, samples, and scopes. + + Args: + config_path (str): Filepath pointing to the master 'DAG_files.json'. + new_era (str/int): The era identifier to check/add. + new_sample (str): The sample identifier to check/add. + new_scope (str): The scope string to check/add. + """ + if Path(config_path).exists(): + with open(config_path, "r") as f: + try: + data = json.load(f) + except json.JSONDecodeError: + data = {"era": [], "sample": [], "scope": []} + else: + data = {"era": [], "sample": [], "scope": []} + + # Update master DAG file data + data["era"] = sorted(list(set(data["era"] + [str(new_era)]))) + data["sample"] = sorted(list(set(data["sample"] + [str(new_sample)]))) + data["scope"] = sorted(list(set(data["scope"] + [str(new_scope)]))) + + # Write updated master DAG file data back + with open(config_path, "w") as f: + json.dump(data, f, indent=4) + + log.debug(f"Updated {config_path} with {new_era}/{new_sample}/{new_scope}")