Skip to content

Commit 120c0c5

Browse files
committed
Create TOML V2
Given that the new things we're adding to the user file aren't compatible with older versions of mantid, we're bumping up to TOML V2 so that we can keep being backwards compatible. This is mainly to stop new user files containing this information from being loaded by older versions, making it clear that you need to upgrade to use these new features and toml files. RE mantidproject#38524
1 parent 314c210 commit 120c0c5

17 files changed

+513
-310
lines changed

scripts/SANS/sans/gui_logic/presenter/run_tab_presenter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from typing import Optional
1919

2020
from mantidqt.utils.observer_pattern import GenericObserver
21-
from sans.user_file.toml_parsers.toml_v1_schema import TomlValidationError
21+
from sans.user_file.toml_parsers.toml_base_schema import TomlValidationError
2222
from ui.sans_isis import SANSSaveOtherWindow
2323
from ui.sans_isis.sans_data_processor_gui import SANSDataProcessorGui
2424
from ui.sans_isis.sans_gui_observable import SansGuiObservable

scripts/SANS/sans/state/StateObjects/StatePolarization.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(self):
5858
# Relevant to He3 filters (Polarisers and Analysers)
5959
self.cell_length = None
6060
self.gas_pressure = None
61+
self.empty_cell = None
6162
# Relevant to all polarisers and analysers.
6263
self.initial_polarization = None
6364

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
from typing import List, Dict
9+
10+
from sans.common.enums import SANSInstrument, SANSFacility
11+
from sans.state.StateObjects.StateData import get_data_builder
12+
from sans.state.StateObjects.StateData import StateData
13+
from sans.test_helper.file_information_mock import SANSFileInformationMock
14+
15+
16+
def get_mock_data_info():
17+
# TODO I really really dislike having to do this in a test, but
18+
# TODO de-coupling StateData is required to avoid it
19+
file_information = SANSFileInformationMock(instrument=SANSInstrument.SANS2D, run_number=22024)
20+
data_builder = get_data_builder(SANSFacility.ISIS, file_information)
21+
data_builder.set_sample_scatter("SANS2D00022024")
22+
data_builder.set_sample_scatter_period(3)
23+
return data_builder.build()
24+
25+
26+
def setup_parser_dict(dict_vals) -> tuple[dict, StateData]:
27+
def _add_missing_mandatory_key(dict_to_check: Dict, key_path: List[str], replacement_val):
28+
_dict = dict_to_check
29+
for key in key_path[0:-1]:
30+
if key not in _dict:
31+
_dict[key] = {}
32+
_dict = _dict[key]
33+
34+
if key_path[-1] not in _dict:
35+
_dict[key_path[-1]] = replacement_val # Add in child value
36+
return dict_to_check
37+
38+
mocked_data_info = get_mock_data_info()
39+
# instrument key needs to generally be present
40+
dict_vals = _add_missing_mandatory_key(dict_vals, ["instrument", "name"], "LOQ")
41+
dict_vals = _add_missing_mandatory_key(dict_vals, ["detector", "configuration", "selected_detector"], "rear")
42+
43+
return dict_vals, mocked_data_info
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
from abc import ABCMeta, abstractmethod
9+
10+
from sans.state.IStateParser import IStateParser
11+
from sans.user_file.toml_parsers.toml_base_schema import TomlSchemaValidator
12+
13+
14+
class TomlParserBase(IStateParser, metaclass=ABCMeta):
15+
def __init__(self, dict_to_parse, file_information, schema_validator: TomlSchemaValidator):
16+
self._validator = schema_validator
17+
self._validator.validate()
18+
19+
self._implementation = None
20+
data_info = self.get_state_data(file_information)
21+
self._implementation = self._get_impl(dict_to_parse, data_info)
22+
self._implementation.parse_all()
23+
24+
@staticmethod
25+
@abstractmethod
26+
def _get_impl(*args):
27+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Mantid Repository : https://github.com/mantidproject/mantid
2+
#
3+
# Copyright © 2025 ISIS Rutherford Appleton Laboratory UKRI,
4+
# NScD Oak Ridge National Laboratory, European Spallation Source
5+
# & Institut Laue - Langevin
6+
# SPDX - License - Identifier: GPL - 3.0 +
7+
8+
import re
9+
from abc import abstractmethod, ABCMeta
10+
11+
12+
class TomlValidationError(Exception):
13+
# A key error would be more appropriate, but there is special
14+
# handling on KeyError which escapes any new lines making it un-readable
15+
pass
16+
17+
18+
class TomlSchemaValidator(object, metaclass=ABCMeta):
19+
# As of the current TOML release there is no way to validate a schema so
20+
# we must provide an implementation
21+
22+
def __init__(self, dict_to_validate):
23+
self._expected_list = self._build_nested_keys(self.reference_schema())
24+
self._to_validate_list = self._build_nested_keys(dict_to_validate)
25+
26+
def validate(self):
27+
self._to_validate_list = filter(lambda s: not s.startswith("metadata"), self._to_validate_list)
28+
unrecognised = set(self._to_validate_list).difference(self._expected_list)
29+
30+
if not unrecognised:
31+
return
32+
33+
# Build any with wildcards
34+
wildcard_matchers = [re.compile(s) for s in self._expected_list if "*" in s]
35+
# Remove anything which matches any the regex wildcards
36+
unrecognised = [s for s in unrecognised if not any(wild_matcher.match(s) for wild_matcher in wildcard_matchers)]
37+
38+
if len(unrecognised) > 0:
39+
err = "The following keys were not recognised:\n"
40+
err += "".join("{0} \n".format(k) for k in unrecognised)
41+
raise TomlValidationError(err)
42+
43+
@staticmethod
44+
@abstractmethod
45+
def reference_schema():
46+
"""
47+
Returns a dictionary layout of all supported keys
48+
:return: Dictionary containing all keys, and values set to None
49+
"""
50+
pass
51+
52+
@staticmethod
53+
def _build_nested_keys(d, path="", current_out=None):
54+
if not current_out:
55+
current_out = []
56+
57+
def make_path(current_path, new_key):
58+
return current_path + "." + new_key if current_path else new_key
59+
60+
for key, v in d.items():
61+
new_path = make_path(path, key)
62+
if isinstance(v, dict):
63+
# Recurse into dict
64+
current_out = TomlSchemaValidator._build_nested_keys(v, new_path, current_out)
65+
elif isinstance(v, set):
66+
# Pack all in from the set of names
67+
for name in v:
68+
current_out.append(make_path(new_path, name))
69+
else:
70+
# This means its a value type with nothing special, so keep name
71+
current_out.append(new_path)
72+
73+
return current_out

scripts/SANS/sans/user_file/toml_parsers/toml_parser.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sans.state.IStateParser import IStateParser
99
from sans.user_file.toml_parsers.toml_reader import TomlReader
1010
from sans.user_file.toml_parsers.toml_v1_parser import TomlV1Parser
11+
from sans.user_file.toml_parsers.toml_v2_parser import TomlV2Parser
1112

1213

1314
class TomlParser(object):
@@ -33,5 +34,7 @@ def get_versioned_parser(toml_dict, file_information) -> IStateParser:
3334
version = toml_dict["toml_file_version"]
3435
if version == 1:
3536
return TomlV1Parser(toml_dict, file_information=file_information)
37+
if version == 2:
38+
return TomlV2Parser(toml_dict, file_information=file_information)
3639
else:
3740
raise NotImplementedError("Version {0} of the SANS Toml Format is not supported".format(version))

scripts/SANS/sans/user_file/toml_parsers/toml_v1_parser.py

+11-73
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@
55
# & Institut Laue - Langevin
66
# SPDX - License - Identifier: GPL - 3.0 +
77

8-
from typing import Optional
9-
108
from sans.common.enums import SANSInstrument, ReductionMode, DetectorType, RangeStepType, FitModeForMerge, DataType, FitType, RebinType
119
from sans.common.general_functions import get_bank_for_spectrum_number, get_detector_types_from_instrument
12-
from sans.state.IStateParser import IStateParser
1310
from sans.state.StateObjects.StateAdjustment import StateAdjustment
1411
from sans.state.StateObjects.StateCalculateTransmission import get_calculate_transmission
1512
from sans.state.StateObjects.StateCompatibility import StateCompatibility
@@ -18,7 +15,7 @@
1815
from sans.state.StateObjects.StateMaskDetectors import get_mask_builder, StateMaskDetectors
1916
from sans.state.StateObjects.StateMoveDetectors import get_move_builder
2017
from sans.state.StateObjects.StateNormalizeToMonitor import get_normalize_to_monitor_builder
21-
from sans.state.StateObjects.StatePolarization import StatePolarization, StateComponent, StateFilter, StateField
18+
from sans.state.StateObjects.StatePolarization import StatePolarization
2219
from sans.state.StateObjects.StateReductionMode import StateReductionMode
2320
from sans.state.StateObjects.StateSave import StateSave
2421
from sans.state.StateObjects.StateScale import StateScale
@@ -28,22 +25,18 @@
2825
from sans.user_file.parser_helpers.toml_parser_impl_base import TomlParserImplBase
2926
from sans.user_file.parser_helpers.wavelength_parser import DuplicateWavelengthStates, WavelengthTomlParser
3027
from sans.user_file.toml_parsers.toml_v1_schema import TomlSchemaV1Validator
28+
from sans.user_file.toml_parsers.toml_base_parser import TomlParserBase
3129

3230

33-
class TomlV1Parser(IStateParser):
31+
class TomlV1Parser(TomlParserBase):
3432
def __init__(self, dict_to_parse, file_information, schema_validator=None):
35-
self._validator = schema_validator if schema_validator else TomlSchemaV1Validator(dict_to_parse)
36-
self._validator.validate()
37-
38-
self._implementation = None
39-
data_info = self.get_state_data(file_information)
40-
self._implementation = self._get_impl(dict_to_parse, data_info)
41-
self._implementation.parse_all()
33+
validator = schema_validator if schema_validator else TomlSchemaV1Validator(dict_to_parse)
34+
super(TomlV1Parser, self).__init__(dict_to_parse, file_information, validator)
4235

4336
@staticmethod
4437
def _get_impl(*args):
4538
# Wrapper which can replaced with a mock
46-
return _TomlV1ParserImpl(*args)
39+
return TomlV1ParserImpl(*args)
4740

4841
def get_state_data(self, file_information):
4942
state_data = super().get_state_data(file_information)
@@ -76,8 +69,9 @@ def get_state_normalize_to_monitor(self, _):
7669
def get_state_reduction_mode(self):
7770
return self._implementation.reduction_mode
7871

79-
def get_state_polarization(self) -> Optional[StatePolarization]:
80-
return self._implementation.polarization
72+
def get_state_polarization(self) -> StatePolarization:
73+
# Not supported by TOML V1, but we return a blank one to keep the parsing results consistent.
74+
return StatePolarization()
8175

8276
def get_state_save(self):
8377
return StateSave()
@@ -98,9 +92,9 @@ def get_state_wavelength_and_pixel_adjustment(self):
9892
return self._implementation.wavelength_and_pixel
9993

10094

101-
class _TomlV1ParserImpl(TomlParserImplBase):
95+
class TomlV1ParserImpl(TomlParserImplBase):
10296
def __init__(self, input_dict, data_info: StateData):
103-
super(_TomlV1ParserImpl, self).__init__(toml_dict=input_dict)
97+
super(TomlV1ParserImpl, self).__init__(toml_dict=input_dict)
10498
# Always take the instrument from the TOML file rather than guessing in the new parser
10599
data_info.instrument = self.instrument
106100
self._create_state_objs(data_info=data_info)
@@ -118,7 +112,6 @@ def parse_all(self):
118112
self._parse_transmission()
119113
self._parse_transmission_roi()
120114
self._parse_transmission_fitting()
121-
self._parse_polarization()
122115

123116
@property
124117
def instrument(self):
@@ -141,7 +134,6 @@ def _create_state_objs(self, data_info):
141134
self.scale = StateScale()
142135
self.wavelength = StateWavelength()
143136
self.wavelength_and_pixel = get_wavelength_and_pixel_adjustment_builder(data_info=data_info).build()
144-
self.polarization = StatePolarization()
145137

146138
# Ensure they are linked up correctly
147139
self.adjustment.calculate_transmission = self.calculate_transmission
@@ -519,60 +511,6 @@ def _parse_mask(self):
519511
if "stop" in phi_mask:
520512
self.mask.phi_max = phi_mask["stop"]
521513

522-
def _parse_polarization(self):
523-
polarization_dict = self.get_val("polarization")
524-
if polarization_dict is None:
525-
return
526-
self.polarization.flipper_configuration = self.get_val("flipper_configuration", polarization_dict)
527-
self.polarization.spin_configuration = self.get_val("spin_configuration", polarization_dict)
528-
flipper_dicts = self.get_val("flipper", polarization_dict)
529-
if flipper_dicts:
530-
for flipper_dict in flipper_dicts.values():
531-
self.polarization.flippers.append(self._parse_component(flipper_dict))
532-
self.polarization.polarizer = self._parse_filter(self.get_val("polarizer", polarization_dict))
533-
self.polarization.analyzer = self._parse_filter(self.get_val("analyzer", polarization_dict))
534-
self.polarization.magnetic_field = self._parse_field(self.get_val("magnetic_field", polarization_dict))
535-
self.polarization.electric_field = self._parse_field(self.get_val("electric_field", polarization_dict))
536-
self.polarization.validate()
537-
538-
def _parse_component(self, component_dict: dict) -> StateComponent:
539-
component_state = StateComponent()
540-
if component_dict is None:
541-
return component_state
542-
component_state.idf_component_name = self.get_val("idf_component_name", component_dict)
543-
component_state.device_name = self.get_val("device_name", component_dict)
544-
component_state.device_type = self.get_val("device_type", component_dict)
545-
location_dict = self.get_val("location", component_dict)
546-
if location_dict:
547-
component_state.location_x = self.get_val("x", location_dict)
548-
component_state.location_y = self.get_val("y", location_dict)
549-
component_state.location_z = self.get_val("z", location_dict)
550-
component_state.transmission = self.get_val("transmission", component_dict)
551-
component_state.efficiency = self.get_val("efficiency", component_dict)
552-
return component_state
553-
554-
def _parse_filter(self, filter_dict: dict) -> StateFilter:
555-
if filter_dict is None:
556-
return StateFilter()
557-
filter_state = self._parse_component(filter_dict)
558-
filter_state.__class__ = StateFilter
559-
filter_state.cell_length = self.get_val("cell_length", filter_dict)
560-
filter_state.gas_pressure = self.get_val("gas_pressure", filter_dict)
561-
return filter_state
562-
563-
def _parse_field(self, field_dict: dict) -> StateField:
564-
field_state = StateField()
565-
if field_dict is None:
566-
return field_state
567-
field_state.sample_strength_log = self.get_val("sample_strength_log", field_dict)
568-
direction_dict = self.get_val("sample_direction", field_dict)
569-
if direction_dict:
570-
field_state.sample_direction_a = self.get_val("a", direction_dict)
571-
field_state.sample_direction_p = self.get_val("p", direction_dict)
572-
field_state.sample_direction_d = self.get_val("d", direction_dict)
573-
field_state.sample_direction_log = self.get_val("sample_direction_log", field_dict)
574-
return field_state
575-
576514
@staticmethod
577515
def _get_1d_min_max(one_d_binning: str):
578516
# TODO: We have to do some special parsing for this type on behalf of the sans codebase

0 commit comments

Comments
 (0)