Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 31 additions & 47 deletions erfa_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def dtype(self):
return "dt_pv"
case "double", (2,):
return "dt_pvdpv"
case _, ():
case (_, ()) | ("eraLDBODY", _):
return super().dtype
raise ValueError(f"ctype {self.ctype} with shape {self.shape} not recognized.")

Expand Down Expand Up @@ -362,71 +362,55 @@ def __init__(self, func: Function, t_erfa_c: str) -> None:
if search is None:
raise RuntimeError(f"cannot find the test for {func.name}")
source = re.sub(r"\s\s+", " ", search.group(1))
self.lines: Final = re.findall(r"\s(.*?);", source, re.DOTALL)
self.definitions: Final = []
self.lines: Final = []
for line in re.findall(r" (.*?);", source):
if line.startswith(("double", "int", "char", "eraASTROM", "eraLDBODY")):
self.definitions.append(line.split(" ", 1))
else:
self.lines.append(line)
self.dt_pv_vars: Final = frozenset(re.findall(r"(\w+)\[2\]\[3\]", source))

def define_arrays(self, line):
"""Check variable definition line for items also needed in python.

E.g., creating an empty astrom structured array.
"""
def process_definitions(self) -> list[str]:
defines = []
# Split line in type and variables.
# E.g., "double x, y, z" will give ctype='double; variables='x, y, z'
ctype, _, variables = line.partition(' ')
for var in variables.split(','):
var = var.strip()
# Is variable an array?
name, _, rest = var.partition('[')
# If not, or one of iymdf or ihmsf, ignore (latter are outputs only).
if not rest or rest[:2] == '4]':
for ctype, variables in self.definitions:
if variables != (
numbers := variables.removeprefix("xyz[] = {").removesuffix("}")
): # Complete hack for single occurrence.
defines.append(f"xyz = np.array([{numbers}])")
continue
if ctype == 'eraLDBODY':
# Special case, since this should be recarray for access similar
# to C struct.
v_dtype = 'dt_eraLDBODY'
v_shape = rest[:rest.index(']')]
extra = ".view(np.recarray)"
else:
for var in variables.split(", "):
if "=" in var: # only happens for double
defines.append(var)
# Is variable an array?
name, _, rest = var.partition("[")
# If not, or one of iymdf or ihmsf, ignore (latter are outputs only).
if not rest or rest[:2] == "4]":
continue
# Temporarily create an Argument, so we can use its attributes.
# This translates, e.g., double pv[2][3] to dtype dt_pv.
v = Argument(ctype + ' ' + var.strip(), '')
v_dtype = v.dtype
v_shape = v.shape if v.signature_shape != '()' else '()'
extra = ""
v_dtype = "float" if v_dtype == "dt_double" else "erfa_ufunc." + v_dtype
defines.append(f"{name} = np.empty({v_shape}, {v_dtype}){extra}")

v = Argument(f"{ctype} {var}", FunctionDoc("", self.func.pyname))
shape = v.shape if v.signature_shape != "()" else "()"
dtype = "float" if v.dtype == "dt_double" else "erfa_ufunc." + v.dtype
defines.append(f"{name} = np.empty({shape}, {dtype})")
if ctype == "eraLDBODY":
# Special case, since this should be recarray for access similar
# to C struct.
defines[-1] += ".view(np.recarray)"
return defines

def to_python(self):
"""Lines defining the body of a python version of the test function."""
# TODO: this is quite hacky right now! Would be good to let function
# calls be understood by the Function class.

out = []
out = self.process_definitions()
for line in self.lines:
# In ldn ufunc, the number of bodies is inferred from the array size,
# so no need to keep the definition.
if line == "n = 3" and self.func.pyname == "ldn":
continue

# Are we dealing with a variable definition that also sets it?
# (hack: only happens for double).
if line.startswith('double') and '=' in line:
# Complete hack for single occurrence.
if line.startswith('double xyz[] = {'):
out.append(f"xyz = np.array([{line[16:-1]}])")
else:
# Put each definition on a separate line.
out.extend([part.strip() for part in line[7:].split(',')])
continue

# Variable definitions: add empty array definition as needed.
if line.startswith(('double', 'int', 'char', 'eraASTROM', 'eraLDBODY')):
out.extend(self.define_arrays(line))
continue

# Actual function. Start with basic replacements.
line = (line
.replace('ERFA_', 'erfa.')
Expand Down
Loading