diff --git a/src/pycalphad_xml/parser.py b/src/pycalphad_xml/parser.py index b51f6ae..d89cee7 100644 --- a/src/pycalphad_xml/parser.py +++ b/src/pycalphad_xml/parser.py @@ -11,20 +11,21 @@ from pathlib import Path this_dir = Path(__file__).parent +def _stringify_node_text(node): + return ''.join(node.xpath('./text()')).replace('\n', '').replace(' ', '').strip() -def convert_math_to_symbolic(math_nodes): + +def convert_math_to_symbolic(math_node): result = 0.0 - interval_nodes = [x for x in math_nodes if (not isinstance(x, str)) and x.tag == 'Interval'] - string_nodes = [x for x in math_nodes if isinstance(x, str)] - for math_node in string_nodes: - # +0 is a hack, for how the function works - result += _sympify_string(math_node+'+0') - result += convert_intervals_to_piecewise(interval_nodes) + interval_nodes = [x for x in math_node if x.tag == 'Interval'] + # +0 is a hack, for how the function works + result += _sympify_string(_stringify_node_text(math_node)+'+0') + result += _convert_intervals_to_piecewise(interval_nodes) result = result.xreplace({Symbol('T'): v.T, Symbol('P'): v.P}) return result -def convert_intervals_to_piecewise(interval_nodes): +def _convert_intervals_to_piecewise(interval_nodes): exprs = [] conds = [] for interval_node in interval_nodes: @@ -33,7 +34,7 @@ def convert_intervals_to_piecewise(interval_nodes): variable = interval_node.attrib['in'] lower = float(interval_node.attrib.get('lower', '-inf')) upper = float(interval_node.attrib.get('upper', 'inf')) - math_expr = convert_math_to_symbolic([''.join(interval_node.itertext()).replace('\n', '').replace(' ', '').strip()]) + math_expr = _sympify_string(_stringify_node_text(interval_node)+'+0') if upper != float('inf'): cond = And(lower <= getattr(v, variable, Symbol(variable)), upper > getattr(v, variable)) else: @@ -181,9 +182,7 @@ def parse_model(dbf, phase_name, model_node, parameters): constituent_array = [[str(c) for c in sorted(lx)] for lx in constituent_array] # Parameter value - # Interval _and_ text (if any) to be able to handle intervals or scalar expressions - param_nodes = param_node.xpath('./Interval') + [''.join(param_node.xpath('./text()')).strip()] - function_obj = convert_math_to_symbolic(param_nodes) + function_obj = convert_math_to_symbolic(param_node) # TODO: Reference @@ -254,9 +253,7 @@ def read_xml(dbf, fd): dbf.species.add(v.Species(species, constituent_dict, charge=species_charge)) elif child.tag == 'Expr': function_name = str(child.attrib['id']) - # Interval _and_ text (if any) to be able to handle intervals or scalar expressions - expr_nodes = child.xpath('./Interval') + [''.join(child.xpath('./text()')).strip()] - function_obj = convert_math_to_symbolic(expr_nodes) + function_obj = convert_math_to_symbolic(child) _setitem_raise_duplicates(dbf.symbols, function_name, function_obj) elif child.tag == 'Phase': model_nodes = child.xpath('./Model')