Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 52 additions & 12 deletions src/mccode_antlr/run/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,43 @@ def float_or_int_or_str(s):
raise ValueError(f'Singular string {string} contains a colon')
return cls(float_or_int_or_str(string))

class EList:
"""An explicit list of values for a parameter."""
def __init__(self, values: list):
self.values = values

def __eq__(self, other: 'EList'):
return all(v == o for v, o in zip(self.values, other.values, strict=True))

def __iter__(self):
return iter(self.values)

def __getitem__(self, index: int):
if index < 0 or index >= len(self):
raise IndexError(f'Index {index} out of range')
return self.values[index]

def __str__(self):
return ','.join(str(v) for v in self.values)

def __repr__(self):
return f'EList({self})'

def __len__(self):
return len(self.values)

@classmethod
def from_str(cls, string):
"""Parse an explicit list string"""
def float_or_int(s):
try:
return int(s)
except ValueError:
pass
return float(s)

return cls([float_or_int(s) for s in string.split(',')])


def parse_list(range_type, unparsed: list[str]):
ranges = {}
Expand All @@ -154,11 +191,11 @@ def parse_list(range_type, unparsed: list[str]):
return ranges


def parameters_to_scan(parameters: dict[str, Union[list, MRange, Singular]], grid: bool = False):
def parameters_to_scan(parameters: dict[str, Union[list, MRange, EList, Singular]], grid: bool = False):
"""Convert a dictionary of ranged parameters to a list of parameter names and an iterable of parameter value tuples.

The ranged parameters can be either MRange objects or lists of values. If a list of values is provided, it will be
iterated over directly.
The ranged parameters can be MRange, EList, Singular objects or lists of values.
If a list of values is provided, it will be iterated over directly.

:parameter parameters: A dictionary of ranged parameters.
:parameter grid: Controls how the parameters are iterated; True implies a grid scan, False implies a linear scan.
Expand Down Expand Up @@ -190,13 +227,15 @@ def parameters_to_scan(parameters: dict[str, Union[list, MRange, Singular]], gri
return n_max, names, zip(*[v if len(v) > 1 else Singular(v[0] if isinstance(v, MRange) else v.value, n_max) for v in values])


def _MRange_or_Singular(s: str):
def _make_scanned_parameter(s: str):
if ':' in s:
return MRange.from_str(s)
elif ',' in s:
return EList.from_str(s)
return Singular.from_str(s)


def parse_command_line_parameters(unparsed: list[str]) -> dict[str, Union[Singular, MRange]]:
def parse_command_line_parameters(unparsed: list[str]) -> dict[str, Union[Singular, EList, MRange]]:
"""Parse a list of input parameters into a dictionary of MRange objects.

:parameter unparsed: A list of parameters.
Expand All @@ -207,25 +246,26 @@ def parse_command_line_parameters(unparsed: list[str]) -> dict[str, Union[Singul
while index < len(unparsed):
if '=' in unparsed[index]:
k, v = unparsed[index].split('=', 1)
ranges[k] = _MRange_or_Singular(v)
ranges[k] = _make_scanned_parameter(v)
elif index + 1 < len(unparsed) and '=' not in unparsed[index + 1]:
ranges[unparsed[index]] = _MRange_or_Singular(unparsed[index + 1])
ranges[unparsed[index]] = _make_scanned_parameter(unparsed[index + 1])
index += 1
else:
raise ValueError(f'Invalid parameter: {unparsed[index]}')
index += 1
return ranges


def parse_scan_parameters(unparsed: list[str]) -> dict[str, MRange | Singular]:
"""Parse a list of input parameters into a dictionary of MRange or Singular objects.
def parse_scan_parameters(unparsed: list[str]) -> dict[str, MRange | EList | Singular]:
"""Parse a list of input parameters into a dictionary of MRange, EList, or Singular objects.

:parameter unparsed: A list of parameters.
:return: A dictionary of MRange or Singular objects. The Singular objects have their maximum length set to the
maximum iterations of all the ranges to avoid infinite iterations.
:return: A dictionary of MRange, EList, or Singular objects.
The Singular objects have their maximum length set to the maximum iterations
of all the ranges to avoid infinite iterations.
"""
ranges = parse_command_line_parameters(unparsed)
max_length = max(len(v) if isinstance(v, MRange) else 1 for v in ranges.values()) if len(ranges) else 1
max_length = max(1 if isinstance(v, Singular) else len(v) for v in ranges.values()) if len(ranges) else 1
for k, v in ranges.items():
if isinstance(v, Singular) and v.maximum is None:
ranges[k] = Singular(v.value, max_length)
Expand Down
Loading