diff --git a/func_adl_uproot/transformer.py b/func_adl_uproot/transformer.py index e723c32..00b5a38 100644 --- a/func_adl_uproot/transformer.py +++ b/func_adl_uproot/transformer.py @@ -8,6 +8,7 @@ input_filenames_argument_name = 'input_filenames' tree_name_argument_name = 'tree_name' +branch_filter_name = '_remove_not_interpretable' unary_op_dict = {ast.UAdd: '+', ast.USub: '-', ast.Invert: '~'} @@ -437,7 +438,9 @@ def visit_Call(self, node): + "(logging.getLogger(__name__).info('Using treename='" + ' + repr(tree_name_to_use)),' + ' uproot.dask({input_file: tree_name_to_use' - + ' for input_file in input_files}))[1])' + + ' for input_file in input_files}, filter_branch=' + + branch_filter_name + + '))[1])' + '(' + source_rep + ', ' diff --git a/func_adl_uproot/translation.py b/func_adl_uproot/translation.py index 0ed20f7..b9b21b5 100644 --- a/func_adl_uproot/translation.py +++ b/func_adl_uproot/translation.py @@ -3,7 +3,39 @@ import qastle from .transformer import PythonSourceGeneratorTransformer -from .transformer import input_filenames_argument_name, tree_name_argument_name +from .transformer import branch_filter_name, input_filenames_argument_name, tree_name_argument_name + +# Adapted from https://github.com/CoffeaTeam/coffea/blob/v2024.4.0/src/coffea/util.py#L217-L248 +remove_not_interpretable_source = ( + ' def ' + + branch_filter_name + + '''(branch): + if isinstance(branch.interpretation, uproot.interpretation.identify.uproot.AsGrouped): + for name, interpretation in branch.interpretation.subbranches.items(): + if isinstance( + interpretation, uproot.interpretation.identify.UnknownInterpretation + ): + logging.getLogger(__name__).warning( + f"Skipping {branch.name} as it is not interpretable by Uproot" + ) + return False + if isinstance(branch.interpretation, uproot.interpretation.identify.UnknownInterpretation): + logging.getLogger(__name__).warning( + f"Skipping {branch.name} as it is not interpretable by Uproot" + ) + return False + try: + _ = branch.interpretation.awkward_form(None) + except uproot.interpretation.objects.CannotBeAwkward: + logging.getLogger(__name__).warning( + f"Skipping {branch.name} as it cannot be represented as an Awkward array" + ) + return False + else: + return True + +''' +) def python_ast_to_python_source(python_ast): @@ -26,7 +58,8 @@ def generate_python_source(ast, function_name='run_query'): + '=None):\n' ) source += ' import functools, logging, numpy as np, dask_awkward as dak, uproot, vector\n' - source += ' vector.register_awkward()\n' + source += ' vector.register_awkward()\n\n' + source += remove_not_interpretable_source source += ' return ' + python_ast_to_python_source(python_ast) + '.compute()\n' return source