Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a165a06

Browse files
committedFeb 20, 2025
Create TOML V2
Given that the new things we're adding to the user file aren't compatible with older versions of mantid, we're bumping up to TOML V2 so that we can keep being backwards compatible. This is mainly to stop new user files containing this information from being loaded by older versions, making it clear that you need to upgrade to use these new features and toml files. RE mantidproject#38524
1 parent e935450 commit a165a06

17 files changed

+513
-310
lines changed
 

‎scripts/SANS/sans/gui_logic/presenter/run_tab_presenter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Optional
1919

2020
from mantidqt.utils.observer_pattern import GenericObserver
21-
from sans.user_file.toml_parsers.toml_v1_schema import TomlValidationError
21+
from sans.user_file.toml_parsers.toml_base_schema import TomlValidationError
2222
from ui.sans_isis import SANSSaveOtherWindow
2323
from ui.sans_isis.sans_data_processor_gui import SANSDataProcessorGui
2424
from ui.sans_isis.sans_gui_observable import SansGuiObservable

‎scripts/SANS/sans/state/StateObjects/StatePolarization.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(self):
5858
# Relevant to He3 filters (Polarisers and Analysers)
5959
self.cell_length = None
6060
self.gas_pressure = None
61+
self.empty_cell = None
6162
# Relevant to all polarisers and analysers.
6263
self.initial_polarization = None
6364

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
from typing import List, Dict
9+
10+
from sans.common.enums import SANSInstrument, SANSFacility
11+
from sans.state.StateObjects.StateData import get_data_builder
12+
from sans.state.StateObjects.StateData import StateData
13+
from sans.test_helper.file_information_mock import SANSFileInformationMock
14+
15+
16+
def get_mock_data_info():
17+
# TODO I really really dislike having to do this in a test, but
18+
# TODO de-coupling StateData is required to avoid it
19+
file_information = SANSFileInformationMock(instrument=SANSInstrument.SANS2D, run_number=22024)
20+
data_builder = get_data_builder(SANSFacility.ISIS, file_information)
21+
data_builder.set_sample_scatter("SANS2D00022024")
22+
data_builder.set_sample_scatter_period(3)
23+
return data_builder.build()
24+
25+
26+
def setup_parser_dict(dict_vals) -> tuple[dict, StateData]:
27+
def _add_missing_mandatory_key(dict_to_check: Dict, key_path: List[str], replacement_val):
28+
_dict = dict_to_check
29+
for key in key_path[0:-1]:
30+
if key not in _dict:
31+
_dict[key] = {}
32+
_dict = _dict[key]
33+
34+
if key_path[-1] not in _dict:
35+
_dict[key_path[-1]] = replacement_val # Add in child value
36+
return dict_to_check
37+
38+
mocked_data_info = get_mock_data_info()
39+
# instrument key needs to generally be present
40+
dict_vals = _add_missing_mandatory_key(dict_vals, ["instrument", "name"], "LOQ")
41+
dict_vals = _add_missing_mandatory_key(dict_vals, ["detector", "configuration", "selected_detector"], "rear")
42+
43+
return dict_vals, mocked_data_info
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
from abc import ABCMeta, abstractmethod
9+
10+
from sans.state.IStateParser import IStateParser
11+
from sans.user_file.toml_parsers.toml_base_schema import TomlSchemaValidator
12+
13+
14+
class TomlParserBase(IStateParser, metaclass=ABCMeta):
15+
def __init__(self, dict_to_parse, file_information, schema_validator: TomlSchemaValidator):
16+
self._validator = schema_validator
17+
self._validator.validate()
18+
19+
self._implementation = None
20+
data_info = self.get_state_data(file_information)
21+
self._implementation = self._get_impl(dict_to_parse, data_info)
22+
self._implementation.parse_all()
23+
24+
@staticmethod
25+
@abstractmethod
26+
def _get_impl(*args):
27+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
import re
9+
from abc import abstractmethod, ABCMeta
10+
11+
12+
class TomlValidationError(Exception):
13+
# A key error would be more appropriate, but there is special
14+
# handling on KeyError which escapes any new lines making it un-readable
15+
pass
16+
17+
18+
class TomlSchemaValidator(object, metaclass=ABCMeta):
19+
# As of the current TOML release there is no way to validate a schema so
20+
# we must provide an implementation
21+
22+
def __init__(self, dict_to_validate):
23+
self._expected_list = self._build_nested_keys(self.reference_schema())
24+
self._to_validate_list = self._build_nested_keys(dict_to_validate)
25+
26+
def validate(self):
27+
self._to_validate_list = filter(lambda s: not s.startswith("metadata"), self._to_validate_list)
28+
unrecognised = set(self._to_validate_list).difference(self._expected_list)
29+
30+
if not unrecognised:
31+
return
32+
33+
# Build any with wildcards
34+
wildcard_matchers = [re.compile(s) for s in self._expected_list if "*" in s]
35+
# Remove anything which matches any the regex wildcards
36+
unrecognised = [s for s in unrecognised if not any(wild_matcher.match(s) for wild_matcher in wildcard_matchers)]
37+
38+
if len(unrecognised) > 0:
39+
err = "The following keys were not recognised:\n"
40+
err += "".join("{0} \n".format(k) for k in unrecognised)
41+
raise TomlValidationError(err)
42+
43+
@staticmethod
44+
@abstractmethod
45+
def reference_schema():
46+
"""
47+
Returns a dictionary layout of all supported keys
48+
:return: Dictionary containing all keys, and values set to None
49+
"""
50+
pass
51+
52+
@staticmethod
53+
def _build_nested_keys(d, path="", current_out=None):
54+
if not current_out:
55+
current_out = []
56+
57+
def make_path(current_path, new_key):
58+
return current_path + "." + new_key if current_path else new_key
59+
60+
for key, v in d.items():
61+
new_path = make_path(path, key)
62+
if isinstance(v, dict):
63+
# Recurse into dict
64+
current_out = TomlSchemaValidator._build_nested_keys(v, new_path, current_out)
65+
elif isinstance(v, set):
66+
# Pack all in from the set of names
67+
for name in v:
68+
current_out.append(make_path(new_path, name))
69+
else:
70+
# This means its a value type with nothing special, so keep name
71+
current_out.append(new_path)
72+
73+
return current_out

‎scripts/SANS/sans/user_file/toml_parsers/toml_parser.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sans.state.IStateParser import IStateParser
99
from sans.user_file.toml_parsers.toml_reader import TomlReader
1010
from sans.user_file.toml_parsers.toml_v1_parser import TomlV1Parser
11+
from sans.user_file.toml_parsers.toml_v2_parser import TomlV2Parser
1112

1213

1314
class TomlParser(object):
@@ -33,5 +34,7 @@ def get_versioned_parser(toml_dict, file_information) -> IStateParser:
3334
version = toml_dict["toml_file_version"]
3435
if version == 1:
3536
return TomlV1Parser(toml_dict, file_information=file_information)
37+
if version == 2:
38+
return TomlV2Parser(toml_dict, file_information=file_information)
3639
else:
3740
raise NotImplementedError("Version {0} of the SANS Toml Format is not supported".format(version))

‎scripts/SANS/sans/user_file/toml_parsers/toml_v1_parser.py

+11-73
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@
55
# & Institut Laue - Langevin
66
# SPDX - License - Identifier: GPL - 3.0 +
77

8-
from typing import Optional
9-
108
from sans.common.enums import SANSInstrument, ReductionMode, DetectorType, RangeStepType, FitModeForMerge, DataType, FitType, RebinType
119
from sans.common.general_functions import get_bank_for_spectrum_number, get_detector_types_from_instrument
12-
from sans.state.IStateParser import IStateParser
1310
from sans.state.StateObjects.StateAdjustment import StateAdjustment
1411
from sans.state.StateObjects.StateCalculateTransmission import get_calculate_transmission
1512
from sans.state.StateObjects.StateCompatibility import StateCompatibility
@@ -18,7 +15,7 @@
1815
from sans.state.StateObjects.StateMaskDetectors import get_mask_builder, StateMaskDetectors
1916
from sans.state.StateObjects.StateMoveDetectors import get_move_builder
2017
from sans.state.StateObjects.StateNormalizeToMonitor import get_normalize_to_monitor_builder
21-
from sans.state.StateObjects.StatePolarization import StatePolarization, StateComponent, StateFilter, StateField
18+
from sans.state.StateObjects.StatePolarization import StatePolarization
2219
from sans.state.StateObjects.StateReductionMode import StateReductionMode
2320
from sans.state.StateObjects.StateSave import StateSave
2421
from sans.state.StateObjects.StateScale import StateScale
@@ -28,22 +25,18 @@
2825
from sans.user_file.parser_helpers.toml_parser_impl_base import TomlParserImplBase
2926
from sans.user_file.parser_helpers.wavelength_parser import DuplicateWavelengthStates, WavelengthTomlParser
3027
from sans.user_file.toml_parsers.toml_v1_schema import TomlSchemaV1Validator
28+
from sans.user_file.toml_parsers.toml_base_parser import TomlParserBase
3129

3230

33-
class TomlV1Parser(IStateParser):
31+
class TomlV1Parser(TomlParserBase):
3432
def __init__(self, dict_to_parse, file_information, schema_validator=None):
35-
self._validator = schema_validator if schema_validator else TomlSchemaV1Validator(dict_to_parse)
36-
self._validator.validate()
37-
38-
self._implementation = None
39-
data_info = self.get_state_data(file_information)
40-
self._implementation = self._get_impl(dict_to_parse, data_info)
41-
self._implementation.parse_all()
33+
validator = schema_validator if schema_validator else TomlSchemaV1Validator(dict_to_parse)
34+
super(TomlV1Parser, self).__init__(dict_to_parse, file_information, validator)
4235

4336
@staticmethod
4437
def _get_impl(*args):
4538
# Wrapper which can replaced with a mock
46-
return _TomlV1ParserImpl(*args)
39+
return TomlV1ParserImpl(*args)
4740

4841
def get_state_data(self, file_information):
4942
state_data = super().get_state_data(file_information)
@@ -76,8 +69,9 @@ def get_state_normalize_to_monitor(self, _):
7669
def get_state_reduction_mode(self):
7770
return self._implementation.reduction_mode
7871

79-
def get_state_polarization(self) -> Optional[StatePolarization]:
80-
return self._implementation.polarization
72+
def get_state_polarization(self) -> StatePolarization:
73+
# Not supported by TOML V1, but we return a blank one to keep the parsing results consistent.
74+
return StatePolarization()
8175

8276
def get_state_save(self):
8377
return StateSave()
@@ -98,9 +92,9 @@ def get_state_wavelength_and_pixel_adjustment(self):
9892
return self._implementation.wavelength_and_pixel
9993

10094

101-
class _TomlV1ParserImpl(TomlParserImplBase):
95+
class TomlV1ParserImpl(TomlParserImplBase):
10296
def __init__(self, input_dict, data_info: StateData):
103-
super(_TomlV1ParserImpl, self).__init__(toml_dict=input_dict)
97+
super(TomlV1ParserImpl, self).__init__(toml_dict=input_dict)
10498
# Always take the instrument from the TOML file rather than guessing in the new parser
10599
data_info.instrument = self.instrument
106100
self._create_state_objs(data_info=data_info)
@@ -118,7 +112,6 @@ def parse_all(self):
118112
self._parse_transmission()
119113
self._parse_transmission_roi()
120114
self._parse_transmission_fitting()
121-
self._parse_polarization()
122115

123116
@property
124117
def instrument(self):
@@ -141,7 +134,6 @@ def _create_state_objs(self, data_info):
141134
self.scale = StateScale()
142135
self.wavelength = StateWavelength()
143136
self.wavelength_and_pixel = get_wavelength_and_pixel_adjustment_builder(data_info=data_info).build()
144-
self.polarization = StatePolarization()
145137

146138
# Ensure they are linked up correctly
147139
self.adjustment.calculate_transmission = self.calculate_transmission
@@ -519,60 +511,6 @@ def _parse_mask(self):
519511
if "stop" in phi_mask:
520512
self.mask.phi_max = phi_mask["stop"]
521513

522-
def _parse_polarization(self):
523-
polarization_dict = self.get_val("polarization")
524-
if polarization_dict is None:
525-
return
526-
self.polarization.flipper_configuration = self.get_val("flipper_configuration", polarization_dict)
527-
self.polarization.spin_configuration = self.get_val("spin_configuration", polarization_dict)
528-
flipper_dicts = self.get_val("flipper", polarization_dict)
529-
if flipper_dicts:
530-
for flipper_dict in flipper_dicts.values():
531-
self.polarization.flippers.append(self._parse_component(flipper_dict))
532-
self.polarization.polarizer = self._parse_filter(self.get_val("polarizer", polarization_dict))
533-
self.polarization.analyzer = self._parse_filter(self.get_val("analyzer", polarization_dict))
534-
self.polarization.magnetic_field = self._parse_field(self.get_val("magnetic_field", polarization_dict))
535-
self.polarization.electric_field = self._parse_field(self.get_val("electric_field", polarization_dict))
536-
self.polarization.validate()
537-
538-
def _parse_component(self, component_dict: dict) -> StateComponent:
539-
component_state = StateComponent()
540-
if component_dict is None:
541-
return component_state
542-
component_state.idf_component_name = self.get_val("idf_component_name", component_dict)
543-
component_state.device_name = self.get_val("device_name", component_dict)
544-
component_state.device_type = self.get_val("device_type", component_dict)
545-
location_dict = self.get_val("location", component_dict)
546-
if location_dict:
547-
component_state.location_x = self.get_val("x", location_dict)
548-
component_state.location_y = self.get_val("y", location_dict)
549-
component_state.location_z = self.get_val("z", location_dict)
550-
component_state.transmission = self.get_val("transmission", component_dict)
551-
component_state.efficiency = self.get_val("efficiency", component_dict)
552-
return component_state
553-
554-
def _parse_filter(self, filter_dict: dict) -> StateFilter:
555-
if filter_dict is None:
556-
return StateFilter()
557-
filter_state = self._parse_component(filter_dict)
558-
filter_state.__class__ = StateFilter
559-
filter_state.cell_length = self.get_val("cell_length", filter_dict)
560-
filter_state.gas_pressure = self.get_val("gas_pressure", filter_dict)
561-
return filter_state
562-
563-
def _parse_field(self, field_dict: dict) -> StateField:
564-
field_state = StateField()
565-
if field_dict is None:
566-
return field_state
567-
field_state.sample_strength_log = self.get_val("sample_strength_log", field_dict)
568-
direction_dict = self.get_val("sample_direction", field_dict)
569-
if direction_dict:
570-
field_state.sample_direction_a = self.get_val("a", direction_dict)
571-
field_state.sample_direction_p = self.get_val("p", direction_dict)
572-
field_state.sample_direction_d = self.get_val("d", direction_dict)
573-
field_state.sample_direction_log = self.get_val("sample_direction_log", field_dict)
574-
return field_state
575-
576514
@staticmethod
577515
def _get_1d_min_max(one_d_binning: str):
578516
# TODO: We have to do some special parsing for this type on behalf of the sans codebase

‎scripts/SANS/sans/user_file/toml_parsers/toml_v1_schema.py

+4-81
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,19 @@
44
# NScD Oak Ridge National Laboratory, European Spallation Source
55
# & Institut Laue - Langevin
66
# SPDX - License - Identifier: GPL - 3.0 +
7-
import re
87

8+
from sans.user_file.toml_parsers.toml_base_schema import TomlSchemaValidator
99

10-
class TomlValidationError(Exception):
11-
# A key error would be more appropriate, but there is special
12-
# handling on KeyError which escapes any new lines making it un-readable
13-
pass
1410

15-
16-
class TomlSchemaV1Validator(object):
11+
class TomlSchemaV1Validator(TomlSchemaValidator):
1712
# As of the current TOML release there is no way to validate a schema so
1813
# we must provide an implementation
1914

20-
# Note : To future devs, if we have a V2 schema a lot of this class could
21-
# be split into a SchemaValidator and an inheriting V1 and V2 schema
22-
# would override the reference schema with the new one
23-
2415
def __init__(self, dict_to_validate):
25-
self._expected_list = self._build_nested_keys(self._reference_schema())
26-
self._to_validate_list = self._build_nested_keys(dict_to_validate)
27-
28-
def validate(self):
29-
self._to_validate_list = filter(lambda s: not s.startswith("metadata"), self._to_validate_list)
30-
unrecognised = set(self._to_validate_list).difference(self._expected_list)
31-
32-
if not unrecognised:
33-
return
34-
35-
# Build any with wildcards
36-
wildcard_matchers = [re.compile(s) for s in self._expected_list if "*" in s]
37-
# Remove anything which matches any the regex wildcards
38-
unrecognised = [s for s in unrecognised if not any(wild_matcher.match(s) for wild_matcher in wildcard_matchers)]
39-
40-
if len(unrecognised) > 0:
41-
err = "The following keys were not recognised:\n"
42-
err += "".join("{0} \n".format(k) for k in unrecognised)
43-
raise TomlValidationError(err)
16+
super(TomlSchemaV1Validator, self).__init__(dict_to_validate)
4417

4518
@staticmethod
46-
def _reference_schema():
19+
def reference_schema():
4720
"""
4821
Returns a dictionary layout of all supported keys
4922
:return: Dictionary containing all keys, and values set to None
@@ -137,32 +110,6 @@ def _reference_schema():
137110
},
138111
}
139112

140-
component_keys = {
141-
"idf_component_name": None,
142-
"device_name": None,
143-
"device_type": None,
144-
"location": {"x", "y", "z"},
145-
"transmission": None,
146-
"efficiency": None,
147-
"empty_cell": None,
148-
"initial_polarization": None,
149-
}
150-
filter_keys = dict(component_keys, **{"cell_length": None, "gas_pressure": None})
151-
field_keys = {
152-
"sample_strength_log": None,
153-
"sample_direction": {"a", "p", "d"},
154-
"sample_direction_log": None,
155-
}
156-
polarization_keys = {
157-
"flipper_configuration": None,
158-
"spin_configuration": None,
159-
"flipper": {"*": component_keys},
160-
"polarizer": filter_keys,
161-
"analyzer": filter_keys,
162-
"magnetic_field": field_keys,
163-
"electric_field": field_keys,
164-
}
165-
166113
return {
167114
"toml_file_version": None,
168115
"binning": binning_keys,
@@ -175,28 +122,4 @@ def _reference_schema():
175122
"q_resolution": q_resolution_keys,
176123
"reduction": reduction_keys,
177124
"transmission": transmission_keys,
178-
"polarization": polarization_keys,
179125
}
180-
181-
@staticmethod
182-
def _build_nested_keys(d, path="", current_out=None):
183-
if not current_out:
184-
current_out = []
185-
186-
def make_path(current_path, new_key):
187-
return current_path + "." + new_key if current_path else new_key
188-
189-
for key, v in d.items():
190-
new_path = make_path(path, key)
191-
if isinstance(v, dict):
192-
# Recurse into dict
193-
current_out = TomlSchemaV1Validator._build_nested_keys(v, new_path, current_out)
194-
elif isinstance(v, set):
195-
# Pack all in from the set of names
196-
for name in v:
197-
current_out.append(make_path(new_path, name))
198-
else:
199-
# This means its a value type with nothing special, so keep name
200-
current_out.append(new_path)
201-
202-
return current_out
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
from sans.state.StateObjects.StateData import StateData
9+
from sans.state.StateObjects.StatePolarization import StatePolarization, StateComponent, StateFilter, StateField
10+
from sans.user_file.toml_parsers.toml_v1_parser import TomlV1Parser, TomlV1ParserImpl
11+
from sans.user_file.toml_parsers.toml_v2_schema import TomlSchemaV2Validator
12+
13+
14+
class TomlV2Parser(TomlV1Parser):
15+
def __init__(self, dict_to_parse, file_information, schema_validator=None):
16+
validator = schema_validator if schema_validator else TomlSchemaV2Validator(dict_to_parse)
17+
super(TomlV2Parser, self).__init__(dict_to_parse, file_information, validator)
18+
19+
@staticmethod
20+
def _get_impl(*args):
21+
# Wrapper which can replaced with a mock
22+
return TomlV2ParserImpl(*args)
23+
24+
def get_state_polarization(self) -> StatePolarization:
25+
return self._implementation.polarization
26+
27+
28+
class TomlV2ParserImpl(TomlV1ParserImpl):
29+
def __init__(self, input_dict, data_info: StateData):
30+
super(TomlV2ParserImpl, self).__init__(input_dict, data_info)
31+
32+
def parse_all(self):
33+
super().parse_all()
34+
self._parse_polarization()
35+
36+
def _create_state_objs(self, data_info):
37+
super()._create_state_objs(data_info)
38+
self.polarization = StatePolarization()
39+
40+
def _parse_polarization(self):
41+
polarization_dict = self.get_val("polarization")
42+
if polarization_dict is None:
43+
return
44+
self.polarization.flipper_configuration = self.get_val("flipper_configuration", polarization_dict)
45+
self.polarization.spin_configuration = self.get_val("spin_configuration", polarization_dict)
46+
flipper_dicts = self.get_val("flipper", polarization_dict)
47+
if flipper_dicts:
48+
for flipper_dict in flipper_dicts.values():
49+
self.polarization.flippers.append(self._parse_component(flipper_dict))
50+
self.polarization.polarizer = self._parse_filter(self.get_val("polarizer", polarization_dict))
51+
self.polarization.analyzer = self._parse_filter(self.get_val("analyzer", polarization_dict))
52+
self.polarization.magnetic_field = self._parse_field(self.get_val("magnetic_field", polarization_dict))
53+
self.polarization.electric_field = self._parse_field(self.get_val("electric_field", polarization_dict))
54+
self.polarization.validate()
55+
56+
def _parse_component(self, component_dict: dict) -> StateComponent:
57+
component_state = StateComponent()
58+
if component_dict is None:
59+
return component_state
60+
component_state.idf_component_name = self.get_val("idf_component_name", component_dict)
61+
component_state.device_name = self.get_val("device_name", component_dict)
62+
component_state.device_type = self.get_val("device_type", component_dict)
63+
location_dict = self.get_val("location", component_dict)
64+
if location_dict:
65+
component_state.location_x = self.get_val("x", location_dict)
66+
component_state.location_y = self.get_val("y", location_dict)
67+
component_state.location_z = self.get_val("z", location_dict)
68+
component_state.transmission = self.get_val("transmission", component_dict)
69+
component_state.efficiency = self.get_val("efficiency", component_dict)
70+
return component_state
71+
72+
def _parse_filter(self, filter_dict: dict) -> StateFilter:
73+
if filter_dict is None:
74+
return StateFilter()
75+
filter_state = self._parse_component(filter_dict)
76+
filter_state.__class__ = StateFilter
77+
filter_state.cell_length = self.get_val("cell_length", filter_dict)
78+
filter_state.gas_pressure = self.get_val("gas_pressure", filter_dict)
79+
filter_state.empty_cell = self.get_val("empty_cell", filter_dict)
80+
filter_state.initial_polarization = self.get_val("initial_polarization", filter_dict)
81+
return filter_state
82+
83+
def _parse_field(self, field_dict: dict) -> StateField:
84+
field_state = StateField()
85+
if field_dict is None:
86+
return field_state
87+
field_state.sample_strength_log = self.get_val("sample_strength_log", field_dict)
88+
direction_dict = self.get_val("sample_direction", field_dict)
89+
if direction_dict:
90+
field_state.sample_direction_a = self.get_val("a", direction_dict)
91+
field_state.sample_direction_p = self.get_val("p", direction_dict)
92+
field_state.sample_direction_d = self.get_val("d", direction_dict)
93+
field_state.sample_direction_log = self.get_val("sample_direction_log", field_dict)
94+
return field_state
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
from sans.user_file.toml_parsers.toml_base_schema import TomlSchemaValidator
9+
from sans.user_file.toml_parsers.toml_v1_schema import TomlSchemaV1Validator
10+
11+
12+
class TomlSchemaV2Validator(TomlSchemaValidator):
13+
# As of the current TOML release there is no way to validate a schema so
14+
# we must provide an implementation
15+
16+
def __init__(self, dict_to_validate):
17+
super(TomlSchemaV2Validator, self).__init__(dict_to_validate)
18+
19+
@staticmethod
20+
def reference_schema():
21+
"""
22+
Returns a dictionary layout of all supported keys. Extends from the V1 Schema.
23+
:return: Dictionary containing all keys, and values set to None
24+
"""
25+
component_keys = {
26+
"idf_component_name": None,
27+
"device_name": None,
28+
"device_type": None,
29+
"location": {"x", "y", "z"},
30+
"transmission": None,
31+
"efficiency": None,
32+
}
33+
filter_keys = dict(
34+
component_keys,
35+
**{
36+
"cell_length": None,
37+
"gas_pressure": None,
38+
"empty_cell": None,
39+
"initial_polarization": None,
40+
},
41+
)
42+
field_keys = {
43+
"sample_strength_log": None,
44+
"sample_direction": {"a", "p", "d"},
45+
"sample_direction_log": None,
46+
}
47+
polarization_keys = {
48+
"flipper_configuration": None,
49+
"spin_configuration": None,
50+
"flipper": {"*": component_keys},
51+
"polarizer": filter_keys,
52+
"analyzer": filter_keys,
53+
"magnetic_field": field_keys,
54+
"electric_field": field_keys,
55+
}
56+
57+
return dict(TomlSchemaV1Validator.reference_schema(), **{"polarization": polarization_keys})

‎scripts/test/SANS/gui_logic/test_run_tab_presenter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sans.test_helper.mock_objects import create_mock_view
2424
from sans.test_helper.user_file_test_helper import create_user_file, sample_user_file
2525
from ui.sans_isis.sans_gui_observable import SansGuiObservable
26-
from sans.user_file.toml_parsers.toml_v1_schema import TomlValidationError
26+
from sans.user_file.toml_parsers.toml_base_schema import TomlValidationError
2727

2828
BATCH_FILE_TEST_CONTENT_1 = [
2929
RowEntries(sample_scatter=1, sample_transmission=2, sample_direct=3, output_name="test_file", user_file="user_test_file"),

‎scripts/test/SANS/user_file/toml_parsers/CMakeLists.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
set(TEST_PY_FILES toml_parser_test.py toml_v1_parser_test.py toml_v1_schema_test.py)
1+
set(TEST_PY_FILES toml_parser_test.py toml_v1_parser_test.py toml_v1_schema_test.py toml_v2_parser_test.py
2+
toml_v2_schema_test.py
3+
)
24

35
add_subdirectory(parser_helpers)
46

‎scripts/test/SANS/user_file/toml_parsers/toml_parser_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ def test_returns_v1_parser(self):
2828
# Check correct params were forwarded on
2929
mocked_import.assert_called_once_with(test_dict, file_information=mocked_file_info)
3030

31+
def test_returns_v2_parser(self):
32+
test_dict = {"toml_file_version": 2}
33+
34+
parser = TomlParser(toml_reader=self.get_mocked_reader(test_dict))
35+
mocked_file_info = mock.NonCallableMock()
36+
with mock.patch("sans.user_file.toml_parsers.toml_parser.TomlV2Parser") as mocked_import:
37+
parser_version = parser.get_toml_parser(toml_file_path=mock.NonCallableMock, file_information=mocked_file_info)
38+
self.assertEqual(mocked_import.return_value, parser_version)
39+
# Check correct params were forwarded on
40+
mocked_import.assert_called_once_with(test_dict, file_information=mocked_file_info)
41+
3142
def test_throws_for_unknown_version(self):
3243
test_dict = {"toml_file_version": 100}
3344
parser = TomlParser(toml_reader=self.get_mocked_reader(test_dict))

‎scripts/test/SANS/user_file/toml_parsers/toml_v1_parser_test.py

+5-152
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55
# & Institut Laue - Langevin
66
# SPDX - License - Identifier: GPL - 3.0 +
77
import unittest
8-
from typing import List, Dict
98
from unittest import mock
109

1110
from sans.common.enums import (
1211
SANSInstrument,
13-
SANSFacility,
1412
DetectorType,
1513
ReductionMode,
1614
RangeStepType,
@@ -19,43 +17,17 @@
1917
FitType,
2018
RebinType,
2119
)
22-
from sans.state.StateObjects.StateData import get_data_builder
2320
from sans.state.StateObjects.StateMaskDetectors import StateMaskDetectors, StateMask
24-
from sans.state.StateObjects.StatePolarization import StatePolarization
25-
from sans.test_helper.file_information_mock import SANSFileInformationMock
2621
from sans.user_file.parser_helpers.toml_parser_impl_base import MissingMandatoryParam
2722
from sans.user_file.toml_parsers.toml_v1_parser import TomlV1Parser
23+
from sans.test_helper.toml_parser_test_helpers import setup_parser_dict
2824

2925

3026
class TomlV1ParserTest(unittest.TestCase):
31-
@staticmethod
32-
def _get_mock_data_info():
33-
# TODO I really really dislike having to do this in a test, but
34-
# TODO de-coupling StateData is required to avoid it
35-
file_information = SANSFileInformationMock(instrument=SANSInstrument.SANS2D, run_number=22024)
36-
data_builder = get_data_builder(SANSFacility.ISIS, file_information)
37-
data_builder.set_sample_scatter("SANS2D00022024")
38-
data_builder.set_sample_scatter_period(3)
39-
return data_builder.build()
40-
41-
def _setup_parser(self, dict_vals) -> TomlV1Parser:
42-
def _add_missing_mandatory_key(dict_to_check: Dict, key_path: List[str], replacement_val):
43-
_dict = dict_to_check
44-
for key in key_path[0:-1]:
45-
if key not in _dict:
46-
_dict[key] = {}
47-
_dict = _dict[key]
48-
49-
if key_path[-1] not in _dict:
50-
_dict[key_path[-1]] = replacement_val # Add in child value
51-
return dict_to_check
52-
53-
self._mocked_data_info = self._get_mock_data_info()
54-
# instrument key needs to generally be present
55-
dict_vals = _add_missing_mandatory_key(dict_vals, ["instrument", "name"], "LOQ")
56-
dict_vals = _add_missing_mandatory_key(dict_vals, ["detector", "configuration", "selected_detector"], "rear")
57-
58-
return TomlV1Parser(dict_vals, file_information=None)
27+
def _setup_parser(self, dict_vals):
28+
setup_dict, mocked_data_info = setup_parser_dict(dict_vals)
29+
self._mocked_data_info = mocked_data_info
30+
return TomlV1Parser(setup_dict, file_information=None)
5931

6032
def test_instrument(self):
6133
parser = self._setup_parser(dict_vals={"instrument": {"name": SANSInstrument.SANS2D.value}})
@@ -696,125 +668,6 @@ def test_parse_mask(self):
696668
self.assertEqual(102, norm_state.prompt_peak_correction_max)
697669
self.assertTrue(norm_state.prompt_peak_correction_enabled)
698670

699-
def test_parse_polarization(self):
700-
top_level_dict = {"polarization": {"flipper_configuration": "00,11,01,10", "spin_configuration": "-1-1,-1+1,+1-1,+1+1"}}
701-
parser_result = self._setup_parser(top_level_dict)
702-
polarization_state = parser_result.get_state_polarization()
703-
704-
self.assertIsInstance(polarization_state, StatePolarization)
705-
self.assertEqual("00,11,01,10", polarization_state.flipper_configuration)
706-
self.assertEqual("-1-1,-1+1,+1-1,+1+1", polarization_state.spin_configuration)
707-
708-
def test_parse_flippers(self):
709-
top_level_dict = {
710-
"polarization": {
711-
"flipper": {
712-
"polarizing": {
713-
"idf_component_name": "name_in_IDF",
714-
"device_name": "flipper1",
715-
"device_type": "coil",
716-
"location": {"x": 1.17, "y": 0.05, "z": 0.045},
717-
"transmission": "trans_ws",
718-
"efficiency": "eff_ws",
719-
},
720-
"analyzing": {
721-
"idf_component_name": "name_in_IDF_a",
722-
"device_name": "flipper2",
723-
"device_type": "coil",
724-
"location": {"x": 2.17, "y": 0.05, "z": 0.045},
725-
"transmission": "trans_ws",
726-
"efficiency": "eff_ws",
727-
},
728-
}
729-
}
730-
}
731-
parser_result = self._setup_parser(top_level_dict)
732-
polarization_state = parser_result.get_state_polarization()
733-
flippers = polarization_state.flippers
734-
self.assertEqual(2, len(flippers))
735-
self.assertEqual("flipper1", flippers[0].device_name)
736-
self.assertEqual("flipper2", flippers[1].device_name)
737-
self.assertEqual(1.17, flippers[0].location_x)
738-
self.assertEqual(0.05, flippers[0].location_y)
739-
self.assertEqual(0.045, flippers[0].location_z)
740-
self.assertEqual(2.17, flippers[1].location_x)
741-
self.assertEqual("name_in_IDF", flippers[0].idf_component_name)
742-
self.assertEqual("coil", flippers[0].device_type)
743-
self.assertEqual("trans_ws", flippers[0].transmission)
744-
self.assertEqual("eff_ws", flippers[0].efficiency)
745-
746-
def test_parse_polarizer_and_analyzer(self):
747-
top_level_dict = {
748-
"polarization": {
749-
"polarizer": {
750-
"idf_component_name": "name_in_IDF_pol",
751-
"device_name": "sm-polarizer",
752-
"device_type": "coil",
753-
"location": {"x": 1.17, "y": 0.05, "z": 0.045},
754-
"transmission": "trans_ws",
755-
"efficiency": "eff_ws",
756-
"cell_length": 0.005,
757-
"gas_pressure": 5,
758-
},
759-
"analyzer": {
760-
"idf_component_name": "name_in_IDF_ana",
761-
"device_name": "3He-analyzer",
762-
"device_type": "coil",
763-
"location": {"x": 2.17, "y": 0.05, "z": 0.045},
764-
"cell_length": 0.006,
765-
"gas_pressure": 6,
766-
"transmission": "trans_ws",
767-
"efficiency": "eff_ws",
768-
},
769-
}
770-
}
771-
parser_result = self._setup_parser(top_level_dict)
772-
polarization_state = parser_result.get_state_polarization()
773-
polarizer_state = polarization_state.polarizer
774-
analyzer_state = polarization_state.analyzer
775-
self.assertEqual(0.006, analyzer_state.cell_length)
776-
self.assertEqual(0.005, polarizer_state.cell_length)
777-
self.assertEqual(6, analyzer_state.gas_pressure)
778-
self.assertEqual(5, polarizer_state.gas_pressure)
779-
self.assertEqual("sm-polarizer", polarizer_state.device_name)
780-
self.assertEqual("3He-analyzer", analyzer_state.device_name)
781-
self.assertEqual(1.17, polarizer_state.location_x)
782-
self.assertEqual(0.05, polarizer_state.location_y)
783-
self.assertEqual(0.045, polarizer_state.location_z)
784-
self.assertEqual(2.17, analyzer_state.location_x)
785-
self.assertEqual("name_in_IDF_pol", polarizer_state.idf_component_name)
786-
self.assertEqual("coil", polarizer_state.device_type)
787-
self.assertEqual("trans_ws", polarizer_state.transmission)
788-
self.assertEqual("eff_ws", polarizer_state.efficiency)
789-
790-
def test_parse_fields(self):
791-
top_level_dict = {
792-
"polarization": {
793-
"magnetic_field": {
794-
"sample_strength_log": "nameoflog",
795-
"sample_direction": {"a": 0, "p": 2.3, "d": 0.002},
796-
},
797-
"electric_field": {
798-
"sample_strength_log": "nameofotherlog",
799-
"sample_direction_log": "nameofanotherlog",
800-
},
801-
}
802-
}
803-
parser_result = self._setup_parser(top_level_dict)
804-
polarization_state = parser_result.get_state_polarization()
805-
electric_state = polarization_state.electric_field
806-
magnetic_state = polarization_state.magnetic_field
807-
self.assertEqual("nameoflog", magnetic_state.sample_strength_log)
808-
self.assertEqual(0, magnetic_state.sample_direction_a)
809-
self.assertEqual(2.3, magnetic_state.sample_direction_p)
810-
self.assertEqual(0.002, magnetic_state.sample_direction_d)
811-
self.assertIsNone(magnetic_state.sample_direction_log)
812-
self.assertEqual("nameofotherlog", electric_state.sample_strength_log)
813-
self.assertEqual("nameofanotherlog", electric_state.sample_direction_log)
814-
self.assertIsNone(electric_state.sample_direction_a)
815-
self.assertIsNone(electric_state.sample_direction_p)
816-
self.assertIsNone(electric_state.sample_direction_d)
817-
818671

819672
if __name__ == "__main__":
820673
unittest.main()

‎scripts/test/SANS/user_file/toml_parsers/toml_v1_schema_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import unittest
88

99
from unittest import mock
10-
from sans.user_file.toml_parsers.toml_v1_schema import TomlSchemaV1Validator, TomlValidationError
10+
from sans.user_file.toml_parsers.toml_v1_schema import TomlSchemaV1Validator
11+
from sans.user_file.toml_parsers.toml_base_schema import TomlValidationError
1112

1213

1314
class SchemaV1ValidatorTest(unittest.TestCase):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
import unittest
9+
10+
from sans.common.enums import SANSInstrument
11+
from sans.state.StateObjects.StatePolarization import StatePolarization
12+
from sans.user_file.toml_parsers.toml_v2_parser import TomlV2Parser
13+
from sans.test_helper.toml_parser_test_helpers import setup_parser_dict
14+
15+
16+
class TomlV2ParserTest(unittest.TestCase):
17+
def _setup_parser(self, dict_vals):
18+
setup_dict, mocked_data_info = setup_parser_dict(dict_vals)
19+
self._mocked_data_info = mocked_data_info
20+
return TomlV2Parser(setup_dict, file_information=None)
21+
22+
# This test is a duplication of the one in the V1 parser test. It's to check the inheritance is working.
23+
def test_instrument(self):
24+
parser = self._setup_parser(dict_vals={"instrument": {"name": SANSInstrument.SANS2D.value}})
25+
inst = parser._implementation.instrument
26+
self.assertTrue(inst is SANSInstrument.SANS2D, msg="Got %r instead" % inst)
27+
28+
def test_parse_polarization(self):
29+
top_level_dict = {"polarization": {"flipper_configuration": "00,11,01,10", "spin_configuration": "-1-1,-1+1,+1-1,+1+1"}}
30+
parser_result = self._setup_parser(top_level_dict)
31+
polarization_state = parser_result.get_state_polarization()
32+
33+
self.assertIsInstance(polarization_state, StatePolarization)
34+
self.assertEqual("00,11,01,10", polarization_state.flipper_configuration)
35+
self.assertEqual("-1-1,-1+1,+1-1,+1+1", polarization_state.spin_configuration)
36+
37+
def test_parse_flippers(self):
38+
top_level_dict = {
39+
"polarization": {
40+
"flipper": {
41+
"polarizing": {
42+
"idf_component_name": "name_in_IDF",
43+
"device_name": "flipper1",
44+
"device_type": "coil",
45+
"location": {"x": 1.17, "y": 0.05, "z": 0.045},
46+
"transmission": "trans_ws",
47+
"efficiency": "eff_ws",
48+
},
49+
"analyzing": {
50+
"idf_component_name": "name_in_IDF_a",
51+
"device_name": "flipper2",
52+
"device_type": "coil",
53+
"location": {"x": 2.17, "y": 0.05, "z": 0.045},
54+
"transmission": "trans_ws",
55+
"efficiency": "eff_ws",
56+
},
57+
}
58+
}
59+
}
60+
parser_result = self._setup_parser(top_level_dict)
61+
polarization_state = parser_result.get_state_polarization()
62+
flippers = polarization_state.flippers
63+
self.assertEqual(2, len(flippers))
64+
self.assertEqual("flipper1", flippers[0].device_name)
65+
self.assertEqual("flipper2", flippers[1].device_name)
66+
self.assertEqual(1.17, flippers[0].location_x)
67+
self.assertEqual(0.05, flippers[0].location_y)
68+
self.assertEqual(0.045, flippers[0].location_z)
69+
self.assertEqual(2.17, flippers[1].location_x)
70+
self.assertEqual("name_in_IDF", flippers[0].idf_component_name)
71+
self.assertEqual("coil", flippers[0].device_type)
72+
self.assertEqual("trans_ws", flippers[0].transmission)
73+
self.assertEqual("eff_ws", flippers[0].efficiency)
74+
75+
def test_parse_polarizer_and_analyzer(self):
76+
top_level_dict = {
77+
"polarization": {
78+
"polarizer": {
79+
"idf_component_name": "name_in_IDF_pol",
80+
"device_name": "sm-polarizer",
81+
"device_type": "coil",
82+
"location": {"x": 1.17, "y": 0.05, "z": 0.045},
83+
"transmission": "trans_ws",
84+
"efficiency": "eff_ws",
85+
"cell_length": 0.005,
86+
"gas_pressure": 5,
87+
},
88+
"analyzer": {
89+
"idf_component_name": "name_in_IDF_ana",
90+
"device_name": "3He-analyzer",
91+
"device_type": "coil",
92+
"location": {"x": 2.17, "y": 0.05, "z": 0.045},
93+
"cell_length": 0.006,
94+
"gas_pressure": 6,
95+
"transmission": "trans_ws",
96+
"efficiency": "eff_ws",
97+
},
98+
}
99+
}
100+
parser_result = self._setup_parser(top_level_dict)
101+
polarization_state = parser_result.get_state_polarization()
102+
polarizer_state = polarization_state.polarizer
103+
analyzer_state = polarization_state.analyzer
104+
self.assertEqual(0.006, analyzer_state.cell_length)
105+
self.assertEqual(0.005, polarizer_state.cell_length)
106+
self.assertEqual(6, analyzer_state.gas_pressure)
107+
self.assertEqual(5, polarizer_state.gas_pressure)
108+
self.assertEqual("sm-polarizer", polarizer_state.device_name)
109+
self.assertEqual("3He-analyzer", analyzer_state.device_name)
110+
self.assertEqual(1.17, polarizer_state.location_x)
111+
self.assertEqual(0.05, polarizer_state.location_y)
112+
self.assertEqual(0.045, polarizer_state.location_z)
113+
self.assertEqual(2.17, analyzer_state.location_x)
114+
self.assertEqual("name_in_IDF_pol", polarizer_state.idf_component_name)
115+
self.assertEqual("coil", polarizer_state.device_type)
116+
self.assertEqual("trans_ws", polarizer_state.transmission)
117+
self.assertEqual("eff_ws", polarizer_state.efficiency)
118+
119+
def test_parse_fields(self):
120+
top_level_dict = {
121+
"polarization": {
122+
"magnetic_field": {
123+
"sample_strength_log": "nameoflog",
124+
"sample_direction": {"a": 0, "p": 2.3, "d": 0.002},
125+
},
126+
"electric_field": {
127+
"sample_strength_log": "nameofotherlog",
128+
"sample_direction_log": "nameofanotherlog",
129+
},
130+
}
131+
}
132+
parser_result = self._setup_parser(top_level_dict)
133+
polarization_state = parser_result.get_state_polarization()
134+
electric_state = polarization_state.electric_field
135+
magnetic_state = polarization_state.magnetic_field
136+
self.assertEqual("nameoflog", magnetic_state.sample_strength_log)
137+
self.assertEqual(0, magnetic_state.sample_direction_a)
138+
self.assertEqual(2.3, magnetic_state.sample_direction_p)
139+
self.assertEqual(0.002, magnetic_state.sample_direction_d)
140+
self.assertIsNone(magnetic_state.sample_direction_log)
141+
self.assertEqual("nameofotherlog", electric_state.sample_strength_log)
142+
self.assertEqual("nameofanotherlog", electric_state.sample_direction_log)
143+
self.assertIsNone(electric_state.sample_direction_a)
144+
self.assertIsNone(electric_state.sample_direction_p)
145+
self.assertIsNone(electric_state.sample_direction_d)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
import unittest
8+
9+
from sans.user_file.toml_parsers.toml_v2_schema import TomlSchemaV2Validator
10+
from sans.user_file.toml_parsers.toml_base_schema import TomlValidationError
11+
12+
13+
class SchemaV2ValidatorTest(unittest.TestCase):
14+
def test_multiple_flippers_respected(self):
15+
valid_example = {"polarization": {"flipper": {"F0": {"device_type": "coil"}, "F1": {"device_type": "magic"}}}}
16+
invalid_example = {"polarization": {"flipper": {"F0": {"device_type": "coil"}}}, "flipper": {"F1": {"fake_key": "magic"}}}
17+
18+
obj = TomlSchemaV2Validator(valid_example)
19+
self.assertIsNone(obj.validate())
20+
21+
with self.assertRaises(TomlValidationError):
22+
TomlSchemaV2Validator(invalid_example).validate()
23+
24+
def test_duplicate_flippers_fails(self):
25+
invalid_example = {"polarization": {"flipper": {"F0": {"device_type": "coil"}}}, "flipper": {"F0": {"device_type": "magic"}}}
26+
27+
with self.assertRaises(TomlValidationError):
28+
TomlSchemaV2Validator(invalid_example).validate()
29+
30+
31+
if __name__ == "__main__":
32+
unittest.main()

0 commit comments

Comments
 (0)
Please sign in to comment.