Skip to content

Commit b3fa0d8

Browse files
committed
Enhances CIF handling in SampleModelFactory
1 parent 39939a5 commit b3fa0d8

File tree

9 files changed

+224
-217
lines changed

9 files changed

+224
-217
lines changed

src/easydiffraction/experiments/categories/background/chebyshev.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ class PolynomialTerm(CategoryItem):
3535
not break immediately. Tests should migrate to the short names.
3636
"""
3737

38-
def __init__(self, *, order: int, coef: float) -> None:
38+
def __init__(
39+
self,
40+
*,
41+
order=None,
42+
coef=None,
43+
) -> None:
3944
super().__init__()
4045

4146
# Canonical descriptors

src/easydiffraction/experiments/categories/background/line_segment.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
class LineSegment(CategoryItem):
2929
"""Single background control point for interpolation."""
3030

31-
def __init__(self, *, x: float, y: float):
31+
def __init__(
32+
self,
33+
*,
34+
x=None,
35+
y=None,
36+
):
3237
super().__init__()
3338

3439
self._x = NumericDescriptor(
@@ -43,7 +48,12 @@ def __init__(self, *, x: float, y: float):
4348
default=0.0,
4449
content_validator=RangeValidator(),
4550
),
46-
cif_handler=CifHandler(names=['_pd_background.line_segment_X']),
51+
cif_handler=CifHandler(
52+
names=[
53+
'_pd_background.line_segment_X',
54+
'_pd_background_line_segment_X',
55+
]
56+
),
4757
)
4858
self._y = Parameter(
4959
name='y', # TODO: rename to intensity
@@ -57,7 +67,12 @@ def __init__(self, *, x: float, y: float):
5767
default=0.0,
5868
content_validator=RangeValidator(),
5969
), # TODO: rename to intensity
60-
cif_handler=CifHandler(names=['_pd_background.line_segment_intensity']),
70+
cif_handler=CifHandler(
71+
names=[
72+
'_pd_background.line_segment_intensity',
73+
'_pd_background_line_segment_intensity',
74+
]
75+
),
6176
)
6277

6378
self._identity.category_code = 'background'

src/easydiffraction/experiments/categories/excluded_regions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class ExcludedRegion(CategoryItem):
2121
def __init__(
2222
self,
2323
*,
24-
start: float,
25-
end: float,
24+
start=None,
25+
end=None,
2626
):
2727
super().__init__()
2828

src/easydiffraction/experiments/categories/linked_phases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class LinkedPhase(CategoryItem):
1919
def __init__(
2020
self,
2121
*,
22-
id: str, # TODO: need new name instead of id
23-
scale: float,
22+
id=None, # TODO: need new name instead of id
23+
scale=None,
2424
):
2525
super().__init__()
2626

src/easydiffraction/experiments/experiment/factory.py

Lines changed: 55 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from easydiffraction.experiments.experiment.enums import RadiationProbeEnum
1414
from easydiffraction.experiments.experiment.enums import SampleFormEnum
1515
from easydiffraction.experiments.experiment.enums import ScatteringTypeEnum
16+
from easydiffraction.io.cif.parse import document_from_path
17+
from easydiffraction.io.cif.parse import name_from_block
18+
from easydiffraction.io.cif.parse import pick_sole_block
19+
from easydiffraction.io.cif.serialize import datastore_from_cif
1620

1721

1822
class ExperimentFactory(FactoryBase):
@@ -42,76 +46,25 @@ class ExperimentFactory(FactoryBase):
4246
}
4347

4448
@classmethod
45-
def create(cls, **kwargs):
46-
"""Create an `ExperimentBase` using a validated argument
47-
combination.
49+
def _make_experiment_type(cls, kwargs):
50+
"""Helper to construct an ExperimentType from keyword arguments,
51+
using defaults as needed.
4852
"""
49-
# Check for valid argument combinations
50-
user_args = {k for k, v in kwargs.items() if v is not None}
51-
cls._validate_args(user_args, cls._ALLOWED_ARG_SPECS, cls.__name__)
52-
53-
# Validate enum arguments if provided
54-
if 'sample_form' in kwargs:
55-
SampleFormEnum(kwargs['sample_form'])
56-
if 'beam_mode' in kwargs:
57-
BeamModeEnum(kwargs['beam_mode'])
58-
if 'radiation_probe' in kwargs:
59-
RadiationProbeEnum(kwargs['radiation_probe'])
60-
if 'scattering_type' in kwargs:
61-
ScatteringTypeEnum(kwargs['scattering_type'])
62-
63-
# Dispatch to the appropriate creation method
64-
if 'cif_path' in kwargs:
65-
return cls._create_from_cif_path(kwargs['cif_path'])
66-
elif 'cif_str' in kwargs:
67-
return cls._create_from_cif_str(kwargs['cif_str'])
68-
elif 'data_path' in kwargs:
69-
return cls._create_from_data_path(kwargs)
70-
elif 'name' in kwargs:
71-
return cls._create_without_data(kwargs)
72-
73-
# -------------
74-
# gemmi helpers
75-
# -------------
76-
77-
# TODO: Move to a common CIF utility module? io.cif.parse?
78-
@staticmethod
79-
def _read_cif_document_from_path(path: str) -> gemmi.cif.Document:
80-
"""Read a CIF document from a file path."""
81-
return gemmi.cif.read_file(path)
82-
83-
# TODO: Move to a common CIF utility module? io.cif.parse?
84-
@staticmethod
85-
def _read_cif_document_from_string(text: str) -> gemmi.cif.Document:
86-
"""Read a CIF document from a raw text string."""
87-
return gemmi.cif.read_string(text)
88-
89-
# TODO: Move to a common CIF utility module? io.cif.parse?
90-
@classmethod
91-
def _pick_first_block(
92-
cls,
93-
doc: gemmi.cif.Document,
94-
) -> gemmi.cif.Block:
95-
"""Pick the first experimental block from a CIF document."""
96-
try:
97-
return doc.sole_block()
98-
except Exception:
99-
return doc[0]
100-
101-
# TODO: Move to a common CIF utility module? io.cif.parse?
102-
@classmethod
103-
def _extract_name_from_block(cls, block: gemmi.cif.Block) -> str:
104-
"""Extract a model name from the CIF block name."""
105-
return block.name or 'model'
53+
return ExperimentType(
54+
sample_form=kwargs.get('sample_form', SampleFormEnum.default()),
55+
beam_mode=kwargs.get('beam_mode', BeamModeEnum.default()),
56+
radiation_probe=kwargs.get('radiation_probe', RadiationProbeEnum.default()),
57+
scattering_type=kwargs.get('scattering_type', ScatteringTypeEnum.default()),
58+
)
10659

10760
# TODO: Move to a common CIF utility module? io.cif.parse?
10861
@classmethod
109-
def _create_experiment_from_block(
62+
def _create_from_gemmi_block(
11063
cls,
11164
block: gemmi.cif.Block,
11265
) -> ExperimentBase:
11366
"""Build a model instance from a single CIF block."""
114-
name = cls._extract_name_from_block(block)
67+
name = name_from_block(block)
11568

11669
# TODO: experiment type need to be read from CIF block
11770
kwargs = {'beam_mode': BeamModeEnum.CONSTANT_WAVELENGTH}
@@ -123,22 +76,24 @@ def _create_experiment_from_block(
12376
expt_class = cls._SUPPORTED[scattering_type][sample_form]
12477
expt_obj = expt_class(name=name, type=expt_type)
12578

126-
# TODO: Read all categories from CIF block into experiment
79+
# Read all categories from CIF block
80+
for category in expt_obj.categories:
81+
category.from_cif(block)
12782

128-
# TODO: Read data from CIF block into experiment datastore
83+
# Populate experiment datastore from CIF block
84+
datastore_from_cif(expt_obj, block)
12985

13086
return expt_obj
13187

132-
# -------------------------------
133-
# Private creation helper methods
134-
# -------------------------------
135-
13688
@classmethod
137-
def _create_from_cif_path(cls, cif_path):
89+
def _create_from_cif_path(
90+
cls,
91+
cif_path: str,
92+
) -> ExperimentBase:
13893
"""Create an experiment from a CIF file path."""
139-
doc = cls._read_cif_document_from_path(cif_path)
140-
block = cls._pick_first_block(doc)
141-
return cls._create_experiment_from_block(block)
94+
doc = document_from_path(cif_path)
95+
block = pick_sole_block(doc)
96+
return cls._create_from_gemmi_block(block)
14297

14398
@staticmethod
14499
def _create_from_cif_str(cif_str):
@@ -182,29 +137,35 @@ def _create_without_data(cls, kwargs):
182137
return expt_obj
183138

184139
@classmethod
185-
def _make_experiment_type(cls, kwargs):
186-
"""Helper to construct an ExperimentType from keyword arguments,
187-
using defaults as needed.
140+
def create(cls, **kwargs):
141+
"""Create an `ExperimentBase` using a validated argument
142+
combination.
188143
"""
189-
return ExperimentType(
190-
sample_form=kwargs.get('sample_form', SampleFormEnum.default()),
191-
beam_mode=kwargs.get('beam_mode', BeamModeEnum.default()),
192-
radiation_probe=kwargs.get('radiation_probe', RadiationProbeEnum.default()),
193-
scattering_type=kwargs.get('scattering_type', ScatteringTypeEnum.default()),
144+
# TODO: move to FactoryBase
145+
# Check for valid argument combinations
146+
user_args = {k for k, v in kwargs.items() if v is not None}
147+
cls._validate_args(
148+
present=user_args,
149+
allowed_specs=cls._ALLOWED_ARG_SPECS,
150+
factory_name=cls.__name__, # TODO: move to FactoryBase
194151
)
195152

196-
@staticmethod
197-
def _is_valid_args(user_args):
198-
"""Validate user argument set against allowed combinations.
153+
# Validate enum arguments if provided
154+
if 'sample_form' in kwargs:
155+
SampleFormEnum(kwargs['sample_form'])
156+
if 'beam_mode' in kwargs:
157+
BeamModeEnum(kwargs['beam_mode'])
158+
if 'radiation_probe' in kwargs:
159+
RadiationProbeEnum(kwargs['radiation_probe'])
160+
if 'scattering_type' in kwargs:
161+
ScatteringTypeEnum(kwargs['scattering_type'])
199162

200-
Returns True if the argument set matches any valid combination,
201-
else False.
202-
"""
203-
user_arg_set = set(user_args)
204-
for arg_set in ExperimentFactory._valid_arg_sets:
205-
required = set(arg_set['required'])
206-
optional = set(arg_set['optional'])
207-
# Must have all required, and only required+optional
208-
if required.issubset(user_arg_set) and user_arg_set <= (required | optional):
209-
return True
210-
return False
163+
# Dispatch to the appropriate creation method
164+
if 'cif_path' in kwargs:
165+
return cls._create_from_cif_path(kwargs['cif_path'])
166+
elif 'cif_str' in kwargs:
167+
return cls._create_from_cif_str(kwargs['cif_str'])
168+
elif 'data_path' in kwargs:
169+
return cls._create_from_data_path(kwargs)
170+
elif 'name' in kwargs:
171+
return cls._create_without_data(kwargs)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# SPDX-FileCopyrightText: 2021-2025 EasyDiffraction contributors <https://github.com/easyscience/diffraction>
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import gemmi
5+
6+
7+
def document_from_path(path: str) -> gemmi.cif.Document:
8+
"""Read a CIF document from a file path."""
9+
return gemmi.cif.read_file(path)
10+
11+
12+
def document_from_string(text: str) -> gemmi.cif.Document:
13+
"""Read a CIF document from a raw text string."""
14+
return gemmi.cif.read_string(text)
15+
16+
17+
def pick_sole_block(doc: gemmi.cif.Document) -> gemmi.cif.Block:
18+
"""Pick the sole data block from a CIF document."""
19+
return doc.sole_block()
20+
21+
22+
def name_from_block(block: gemmi.cif.Block) -> str:
23+
"""Extract a model name from the CIF block name."""
24+
# TODO: Need validator or normalization?
25+
return block.name

0 commit comments

Comments
 (0)