From 90ff7fc4003cfafbb0758b59c0ddb9a9817a8474 Mon Sep 17 00:00:00 2001 From: Mason Proffitt Date: Mon, 8 Apr 2024 17:00:26 +0200 Subject: [PATCH] skip branches that uproot/awkward can't handle #82 --- func_adl_uproot/transformer.py | 3 ++- func_adl_uproot/translation.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/func_adl_uproot/transformer.py b/func_adl_uproot/transformer.py index e723c32..5e77347 100644 --- a/func_adl_uproot/transformer.py +++ b/func_adl_uproot/transformer.py @@ -8,6 +8,7 @@ input_filenames_argument_name = 'input_filenames' tree_name_argument_name = 'tree_name' +branch_filter_name = '_remove_not_interpretable' unary_op_dict = {ast.UAdd: '+', ast.USub: '-', ast.Invert: '~'} @@ -437,7 +438,7 @@ def visit_Call(self, node): + "(logging.getLogger(__name__).info('Using treename='" + ' + repr(tree_name_to_use)),' + ' uproot.dask({input_file: tree_name_to_use' - + ' for input_file in input_files}))[1])' + + ' for input_file in input_files}, filter_branch=' + branch_filter_name+ '))[1])' + '(' + source_rep + ', ' diff --git a/func_adl_uproot/translation.py b/func_adl_uproot/translation.py index 0ed20f7..285f5a7 100644 --- a/func_adl_uproot/translation.py +++ b/func_adl_uproot/translation.py @@ -3,7 +3,33 @@ import qastle from .transformer import PythonSourceGeneratorTransformer -from .transformer import input_filenames_argument_name, tree_name_argument_name +from .transformer import branch_filter_name, input_filenames_argument_name, tree_name_argument_name + +# Adapted from https://github.com/CoffeaTeam/coffea/blob/e2cd5e291e90314b619a40a1ecd2649f1b2de00f/src/coffea/util.py#L217-L248 +remove_not_interpretable_source = ''' def ''' + branch_filter_name + '''(branch): + if isinstance(branch.interpretation, uproot.interpretation.identify.uproot.AsGrouped): + for name, interpretation in branch.interpretation.subbranches.items(): + if isinstance(interpretation, uproot.interpretation.identify.UnknownInterpretation): + logging.getLogger(__name__).warning( + f"Skipping {branch.name} as it is not interpretable by Uproot" + ) + return False + if isinstance(branch.interpretation, uproot.interpretation.identify.UnknownInterpretation): + logging.getLogger(__name__).warning( + f"Skipping {branch.name} as it is not interpretable by Uproot" + ) + return False + try: + _ = branch.interpretation.awkward_form(None) + except uproot.interpretation.objects.CannotBeAwkward: + logging.getLogger(__name__).warning( + f"Skipping {branch.name} as it cannot be represented as an Awkward array" + ) + return False + else: + return True + +''' def python_ast_to_python_source(python_ast): @@ -26,7 +52,8 @@ def generate_python_source(ast, function_name='run_query'): + '=None):\n' ) source += ' import functools, logging, numpy as np, dask_awkward as dak, uproot, vector\n' - source += ' vector.register_awkward()\n' + source += ' vector.register_awkward()\n\n' + source += remove_not_interpretable_source source += ' return ' + python_ast_to_python_source(python_ast) + '.compute()\n' return source