diff --git a/erfa/core.py.templ b/erfa/core.py.templ index 33d4c3f..79be4cc 100644 --- a/erfa/core.py.templ +++ b/erfa/core.py.templ @@ -158,34 +158,7 @@ def {{ func.pyname }}({{ func.args_by_inout('in|inout')|map(attribute='name')|jo {{ func.doc.doc | indent(6, true) }} """ - - {#- - # Call the ufunc. Note that we pass inout twice, once as input - # and once as output, so that changes are done in-place - #} - {{ func.python_call }} - {#- - # Check whether any warnings or errors occurred. - #} - {%- for arg in func.args_by_inout('stat') %} - check_errwarn({{ arg.name }}, '{{ func.pyname }}') - {%- endfor %} - {#- - # Any string outputs will be in structs; view them as their base type. - #} - {%- for arg in func.args_by_inout('out') -%} - {%- if arg.ctype == "char" %} - {{ arg.name }} = {{ arg.name }}.view(dt_bytes1) - {%- endif %} - {%- endfor %} - {#- - # Return the output arguments (including the inplace ones) - #} - {%- if func.result_tuple %} - return {{ func.result_tuple.create() }} - {%- else %} - return {{ func.args_by_inout('inout|out|ret')[0].name }} - {%- endif %} + {{ func.generate_python_body() | indent(4) }} {%- endfor %} {# done! (note: this comment also ensures final new line!) #} diff --git a/erfa_generator.py b/erfa_generator.py index 62d446e..8701f23 100644 --- a/erfa_generator.py +++ b/erfa_generator.py @@ -351,6 +351,32 @@ def python_call(self): return ('(' + result[:split_point] + '\n ' + result[split_point:].replace(' =', ') =')) + def generate_python_body(self) -> str: + lines = [self.python_call] + if status_code := self.args_by_inout("stat"): + lines.append(f"check_errwarn({status_code[0].name}, {self.pyname!r})") + lines.extend( + f"{arg.name} = {arg.name}.view(dt_bytes1)" + for arg in self.args_by_inout("out") + if arg.ctype == "char" + ) + if len(lines) == 1: + arg_names = [arg.name for arg in self.args_by_inout("in|inout")] + ufunc_call = f"ufunc.{self.pyname}({', '.join(arg_names)})" + ret_val = ( + ufunc_call + if self.result_tuple is None + else f"{self.result_tuple.name}(*{ufunc_call})" + ) + return f"return {ret_val}" + ret_val = ( + self.args_by_inout("inout|out|ret")[0].name + if self.result_tuple is None + else self.result_tuple.create() + ) + lines.append(f"return {ret_val}") + return "\n".join(lines) + class Constant: