diff --git a/erfa_generator.py b/erfa_generator.py index 3c1315d..bd854be 100644 --- a/erfa_generator.py +++ b/erfa_generator.py @@ -171,7 +171,7 @@ def dtype(self): return "dt_pv" case "double", (2,): return "dt_pvdpv" - case _, (): + case (_, ()) | ("eraLDBODY", _): return super().dtype raise ValueError(f"ctype {self.ctype} with shape {self.shape} not recognized.") @@ -362,41 +362,41 @@ def __init__(self, func: Function, t_erfa_c: str) -> None: if search is None: raise RuntimeError(f"cannot find the test for {func.name}") source = re.sub(r"\s\s+", " ", search.group(1)) - self.lines: Final = re.findall(r"\s(.*?);", source, re.DOTALL) + self.definitions: Final = [] + self.lines: Final = [] + for line in re.findall(r" (.*?);", source): + if line.startswith(("double", "int", "char", "eraASTROM", "eraLDBODY")): + self.definitions.append(line.split(" ", 1)) + else: + self.lines.append(line) self.dt_pv_vars: Final = frozenset(re.findall(r"(\w+)\[2\]\[3\]", source)) - def define_arrays(self, line): - """Check variable definition line for items also needed in python. - - E.g., creating an empty astrom structured array. - """ + def process_definitions(self) -> list[str]: defines = [] - # Split line in type and variables. - # E.g., "double x, y, z" will give ctype='double; variables='x, y, z' - ctype, _, variables = line.partition(' ') - for var in variables.split(','): - var = var.strip() - # Is variable an array? - name, _, rest = var.partition('[') - # If not, or one of iymdf or ihmsf, ignore (latter are outputs only). - if not rest or rest[:2] == '4]': + for ctype, variables in self.definitions: + if variables != ( + numbers := variables.removeprefix("xyz[] = {").removesuffix("}") + ): # Complete hack for single occurrence. + defines.append(f"xyz = np.array([{numbers}])") continue - if ctype == 'eraLDBODY': - # Special case, since this should be recarray for access similar - # to C struct. - v_dtype = 'dt_eraLDBODY' - v_shape = rest[:rest.index(']')] - extra = ".view(np.recarray)" - else: + for var in variables.split(", "): + if "=" in var: # only happens for double + defines.append(var) + # Is variable an array? + name, _, rest = var.partition("[") + # If not, or one of iymdf or ihmsf, ignore (latter are outputs only). + if not rest or rest[:2] == "4]": + continue # Temporarily create an Argument, so we can use its attributes. # This translates, e.g., double pv[2][3] to dtype dt_pv. - v = Argument(ctype + ' ' + var.strip(), '') - v_dtype = v.dtype - v_shape = v.shape if v.signature_shape != '()' else '()' - extra = "" - v_dtype = "float" if v_dtype == "dt_double" else "erfa_ufunc." + v_dtype - defines.append(f"{name} = np.empty({v_shape}, {v_dtype}){extra}") - + v = Argument(f"{ctype} {var}", FunctionDoc("", self.func.pyname)) + shape = v.shape if v.signature_shape != "()" else "()" + dtype = "float" if v.dtype == "dt_double" else "erfa_ufunc." + v.dtype + defines.append(f"{name} = np.empty({shape}, {dtype})") + if ctype == "eraLDBODY": + # Special case, since this should be recarray for access similar + # to C struct. + defines[-1] += ".view(np.recarray)" return defines def to_python(self): @@ -404,29 +404,13 @@ def to_python(self): # TODO: this is quite hacky right now! Would be good to let function # calls be understood by the Function class. - out = [] + out = self.process_definitions() for line in self.lines: # In ldn ufunc, the number of bodies is inferred from the array size, # so no need to keep the definition. if line == "n = 3" and self.func.pyname == "ldn": continue - # Are we dealing with a variable definition that also sets it? - # (hack: only happens for double). - if line.startswith('double') and '=' in line: - # Complete hack for single occurrence. - if line.startswith('double xyz[] = {'): - out.append(f"xyz = np.array([{line[16:-1]}])") - else: - # Put each definition on a separate line. - out.extend([part.strip() for part in line[7:].split(',')]) - continue - - # Variable definitions: add empty array definition as needed. - if line.startswith(('double', 'int', 'char', 'eraASTROM', 'eraLDBODY')): - out.extend(self.define_arrays(line)) - continue - # Actual function. Start with basic replacements. line = (line .replace('ERFA_', 'erfa.')