@@ -28,21 +28,21 @@ def validate_inputs(inputs, _):
28
28
"""Validate the inputs before launching the WorkChain."""
29
29
structure = inputs ['structure' ]
30
30
elements_present = [kind .name for kind in structure .kinds ]
31
- absorbing_elements_list = sorted (inputs ['elements_list' ])
32
31
abs_atom_marker = inputs ['abs_atom_marker' ].value
33
32
if abs_atom_marker in elements_present :
34
33
raise ValidationError (
35
34
f'The marker given for the absorbing atom ("{ abs_atom_marker } ") matches an existing Kind in the '
36
35
f'input structure ({ elements_present } ).'
37
36
)
38
-
39
- if inputs ['calc_binding_energy' ].value :
40
- ce_list = sorted (inputs ['correction_energies' ].get_dict ().keys ())
41
- if ce_list != absorbing_elements_list :
42
- raise ValidationError (
43
- f'The ``correction_energies`` provided ({ ce_list } ) does not match the list of'
44
- f' absorbing elements ({ absorbing_elements_list } )'
45
- )
37
+ if 'elements_list' in inputs :
38
+ absorbing_elements_list = sorted (inputs ['elements_list' ])
39
+ if inputs ['calc_binding_energy' ].value :
40
+ ce_list = sorted (inputs ['correction_energies' ].get_dict ().keys ())
41
+ if ce_list != absorbing_elements_list :
42
+ raise ValidationError (
43
+ f'The ``correction_energies`` provided ({ ce_list } ) does not match the list of'
44
+ f' absorbing elements ({ absorbing_elements_list } )'
45
+ )
46
46
47
47
48
48
class XpsWorkChain (ProtocolMixin , WorkChain ):
@@ -81,7 +81,7 @@ def define(cls, spec):
81
81
spec .expose_inputs (
82
82
PwBaseWorkChain ,
83
83
namespace = 'ch_scf' ,
84
- exclude = ('kpoints' , ' pw.structure' ),
84
+ exclude = ('pw.structure' , ),
85
85
namespace_options = {
86
86
'help' : ('Input parameters for the basic xps workflow (core-hole SCF).' ),
87
87
'validator' : None
@@ -170,6 +170,14 @@ def define(cls, spec):
170
170
'The list of elements to be considered for analysis, each must be valid elements of the periodic table.'
171
171
)
172
172
)
173
+ spec .input (
174
+ 'atoms_list' ,
175
+ valid_type = orm .List ,
176
+ required = False ,
177
+ help = (
178
+ 'The indices of atoms to be considered for analysis.'
179
+ )
180
+ )
173
181
spec .input (
174
182
'calc_binding_energy' ,
175
183
valid_type = orm .Bool ,
@@ -233,12 +241,14 @@ def define(cls, spec):
233
241
spec .output (
234
242
'supercell_structure' ,
235
243
valid_type = orm .StructureData ,
244
+ required = False ,
236
245
help = ('The supercell of ``outputs.standardized_structure`` used to generate structures for'
237
246
' XPS sub-processes.' )
238
247
)
239
248
spec .output (
240
249
'symmetry_analysis_data' ,
241
250
valid_type = orm .Dict ,
251
+ required = False ,
242
252
help = 'The output parameters from ``get_xspectra_structures()``.'
243
253
)
244
254
spec .output (
@@ -366,8 +376,8 @@ def get_treatment_filepath(cls):
366
376
@classmethod
367
377
def get_builder_from_protocol (
368
378
cls , code , structure , pseudos , core_hole_treatments = None , protocol = None ,
369
- overrides = None , elements_list = None , options = None ,
370
- structure_preparation_settings = None , ** kwargs
379
+ overrides = None , elements_list = None , atoms_list = None , options = None ,
380
+ structure_preparation_settings = None , correction_energies = None , ** kwargs
371
381
):
372
382
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
373
383
@@ -386,9 +396,6 @@ def get_builder_from_protocol(
386
396
"""
387
397
388
398
inputs = cls .get_protocol_inputs (protocol , overrides )
389
- calc_binding_energy = kwargs .pop ('calc_binding_energy' , False )
390
- correction_energies = kwargs .pop ('correction_energies' , orm .Dict ())
391
-
392
399
pw_args = (code , structure , protocol )
393
400
# xspectra_args = (pw_code, xs_code, structure, protocol, upf2plotcore_code)
394
401
@@ -412,8 +419,11 @@ def get_builder_from_protocol(
412
419
builder .ch_scf = ch_scf
413
420
builder .structure = structure
414
421
builder .abs_atom_marker = abs_atom_marker
415
- builder .calc_binding_energy = calc_binding_energy
416
- builder .correction_energies = correction_energies
422
+ if correction_energies :
423
+ builder .correction_energies = orm .Dict (correction_energies )
424
+ builder .calc_binding_energy = orm .Bool (True )
425
+ else :
426
+ builder .calc_binding_energy = orm .Bool (False )
417
427
builder .clean_workdir = orm .Bool (inputs ['clean_workdir' ])
418
428
core_hole_pseudos = {}
419
429
gipaw_pseudos = {}
@@ -434,6 +444,12 @@ def get_builder_from_protocol(
434
444
for element in elements_list :
435
445
core_hole_pseudos [element ] = pseudos [element ]['core_hole' ]
436
446
gipaw_pseudos [element ] = pseudos [element ]['gipaw' ]
447
+ elif atoms_list :
448
+ builder .atoms_list = orm .List (atoms_list )
449
+ for index in atoms_list :
450
+ element = structure .sites [index ].kind_name
451
+ core_hole_pseudos [element ] = pseudos [element ]['core_hole' ]
452
+ gipaw_pseudos [element ] = pseudos [element ]['gipaw' ]
437
453
# if no elements list is given, we instead initalise the pseudos dict with all
438
454
# elements in the structure
439
455
else :
@@ -453,12 +469,18 @@ def get_builder_from_protocol(
453
469
454
470
def setup (self ):
455
471
"""Init required context variables."""
456
- custom_elements_list = self .inputs .get ('elements_list' , None )
457
- if not custom_elements_list :
472
+ elements_list = self .inputs .get ('elements_list' , None )
473
+ atoms_list = self .inputs .get ('atoms_list' , None )
474
+ if elements_list :
475
+ self .ctx .elements_list = elements_list .get_list ()
476
+ self .ctx .atoms_list = None
477
+ elif atoms_list :
478
+ self .ctx .atoms_list = atoms_list .get_list ()
479
+ self .ctx .elements_list = None
480
+ else :
458
481
structure = self .inputs .structure
459
482
self .ctx .elements_list = [Kind .symbol for Kind in structure .kinds ]
460
- else :
461
- self .ctx .elements_list = custom_elements_list .get_list ()
483
+
462
484
463
485
464
486
def should_run_relax (self ):
@@ -511,48 +533,59 @@ def prepare_structures(self):
511
533
formatted as {<variable_name> : <parameter>} for each variable in the
512
534
``get_symmetry_dataset()`` method.
513
535
"""
536
+ from aiida_quantumespresso .workflows .functions .get_marked_structures import get_marked_structures
514
537
from aiida_quantumespresso .workflows .functions .get_xspectra_structures import get_xspectra_structures
515
538
516
- elements_list = orm .List (self .ctx .elements_list )
517
- inputs = {
518
- 'absorbing_elements_list' : elements_list ,
519
- 'absorbing_atom_marker' : self .inputs .abs_atom_marker ,
520
- 'metadata' : {
521
- 'call_link_label' : 'get_xspectra_structures'
539
+ input_structure = self .inputs .structure if 'relax' not in self .inputs else self .ctx .relaxed_structure
540
+ if self .ctx .elements_list :
541
+ elements_list = orm .List (self .ctx .elements_list )
542
+ inputs = {
543
+ 'absorbing_elements_list' : elements_list ,
544
+ 'absorbing_atom_marker' : self .inputs .abs_atom_marker ,
545
+ 'metadata' : {
546
+ 'call_link_label' : 'get_xspectra_structures'
547
+ }
548
+ } # populate this further once the schema for WorkChain options is figured out
549
+ if 'structure_preparation_settings' in self .inputs :
550
+ optional_cell_prep = self .inputs .structure_preparation_settings
551
+ for key , node in optional_cell_prep .items ():
552
+ inputs [key ] = node
553
+ if 'spglib_settings' in self .inputs :
554
+ spglib_settings = self .inputs .spglib_settings
555
+ inputs ['spglib_settings' ] = spglib_settings
556
+ else :
557
+ spglib_settings = None
558
+
559
+ result = get_xspectra_structures (input_structure , ** inputs )
560
+
561
+ supercell = result .pop ('supercell' )
562
+ out_params = result .pop ('output_parameters' )
563
+ if out_params .get_dict ().get ('structure_is_standardized' , None ):
564
+ standardized = result .pop ('standardized_structure' )
565
+ self .out ('standardized_structure' , standardized )
566
+
567
+ # structures_to_process = {Key : Value for Key, Value in result.items()}
568
+ for site in ['output_parameters' , 'supercell' , 'standardized_structure' ]:
569
+ result .pop (site , None )
570
+ self .ctx .supercell = supercell
571
+ self .ctx .equivalent_sites_data = out_params ['equivalent_sites_data' ]
572
+ self .out ('supercell_structure' , supercell )
573
+ self .out ('symmetry_analysis_data' , out_params )
574
+ elif self .ctx .atoms_list :
575
+ atoms_list = orm .List (self .ctx .atoms_list )
576
+ inputs = {
577
+ 'atoms_list' : atoms_list ,
578
+ 'marker' : self .inputs .abs_atom_marker ,
579
+ 'metadata' : {
580
+ 'call_link_label' : 'get_marked_structures'
581
+ }
522
582
}
523
- } # populate this further once the schema for WorkChain options is figured out
524
- if 'structure_preparation_settings' in self .inputs :
525
- optional_cell_prep = self .inputs .structure_preparation_settings
526
- for key , node in optional_cell_prep .items ():
527
- inputs [key ] = node
528
- if 'spglib_settings' in self .inputs :
529
- spglib_settings = self .inputs .spglib_settings
530
- inputs ['spglib_settings' ] = spglib_settings
531
- else :
532
- spglib_settings = None
533
-
534
- if 'relax' in self .inputs :
535
- relaxed_structure = self .ctx .relaxed_structure
536
- result = get_xspectra_structures (relaxed_structure , ** inputs )
537
- else :
538
- result = get_xspectra_structures (self .inputs .structure , ** inputs )
539
-
540
- supercell = result .pop ('supercell' )
541
- out_params = result .pop ('output_parameters' )
542
- if out_params .get_dict ().get ('structure_is_standardized' , None ):
543
- standardized = result .pop ('standardized_structure' )
544
- self .out ('standardized_structure' , standardized )
545
-
546
- # structures_to_process = {Key : Value for Key, Value in result.items()}
547
- for site in ['output_parameters' , 'supercell' , 'standardized_structure' ]:
548
- result .pop (site , None )
583
+ result = get_marked_structures (input_structure , ** inputs )
584
+ self .ctx .supercell = input_structure
585
+ self .ctx .equivalent_sites_data = result .pop ('output_parameters' ).get_dict ()
549
586
structures_to_process = {f'{ Key .split ("_" )[0 ]} _{ Key .split ("_" )[1 ]} ' : Value for Key , Value in result .items ()}
550
- self .ctx . supercell = supercell
587
+ self .report ( f'structures_to_process: { structures_to_process } ' )
551
588
self .ctx .structures_to_process = structures_to_process
552
- self .ctx .equivalent_sites_data = out_params ['equivalent_sites_data' ]
553
-
554
- self .out ('supercell_structure' , supercell )
555
- self .out ('symmetry_analysis_data' , out_params )
556
589
557
590
def should_run_gs_scf (self ):
558
591
"""If the 'calc_binding_energy' input namespace is True, we run a scf calculation for the supercell."""
@@ -566,9 +599,9 @@ def run_gs_scf(self):
566
599
inputs .metadata .call_link_label = 'supercell_xps'
567
600
568
601
inputs = prepare_process_inputs (PwBaseWorkChain , inputs )
569
- equivalent_sites_data = self . ctx . equivalent_sites_data
570
- for site in equivalent_sites_data :
571
- abs_element = equivalent_sites_data [site ]['symbol' ]
602
+ # pseudos for all elements to be calculated should be replaced
603
+ for site in self . ctx . equivalent_sites_data :
604
+ abs_element = self . ctx . equivalent_sites_data [site ]['symbol' ]
572
605
inputs .pw .pseudos [abs_element ] = self .inputs .gipaw_pseudos [abs_element ]
573
606
running = self .submit (PwBaseWorkChain , ** inputs )
574
607
@@ -600,7 +633,6 @@ def run_all_scf(self):
600
633
equivalent_sites_data = self .ctx .equivalent_sites_data
601
634
abs_atom_marker = self .inputs .abs_atom_marker .value
602
635
603
-
604
636
for site in structures_to_process :
605
637
inputs = AttributeDict (self .exposed_inputs (PwBaseWorkChain , namespace = 'ch_scf' ))
606
638
structure = structures_to_process [site ]
@@ -630,9 +662,10 @@ def run_all_scf(self):
630
662
631
663
core_hole_pseudo = self .inputs .core_hole_pseudos [abs_element ]
632
664
inputs .pw .pseudos [abs_atom_marker ] = core_hole_pseudo
633
- # all element in the elements_list should be replaced
634
- for element in self .inputs .elements_list :
635
- inputs .pw .pseudos [element ] = self .inputs .gipaw_pseudos [element ]
665
+ # pseudos for all elements to be calculated should be replaced
666
+ for key in self .ctx .equivalent_sites_data :
667
+ abs_element = self .ctx .equivalent_sites_data [key ]['symbol' ]
668
+ inputs .pw .pseudos [abs_element ] = self .inputs .gipaw_pseudos [abs_element ]
636
669
# remove pseudo if the only element is replaced by the marker
637
670
inputs .pw .pseudos = {kind .name : inputs .pw .pseudos [kind .name ] for kind in structure .kinds }
638
671
@@ -674,11 +707,15 @@ def results(self):
674
707
kwargs ['correction_energies' ] = self .inputs .correction_energies
675
708
kwargs ['metadata' ] = {'call_link_label' : 'compile_final_spectra' }
676
709
677
- equivalent_sites_data = orm .Dict (dict = self .ctx .equivalent_sites_data )
678
- elements_list = orm .List (list = self .ctx .elements_list )
710
+ if self .ctx .elements_list :
711
+ elements_list = orm .List (list = self .ctx .elements_list )
712
+ else :
713
+ symbols = {value ['symbol' ] for value in self .ctx .equivalent_sites_data .values ()}
714
+ elements_list = orm .List (list (symbols ))
679
715
voight_gamma = self .inputs .voight_gamma
680
716
voight_sigma = self .inputs .voight_sigma
681
717
718
+ equivalent_sites_data = orm .Dict (dict = self .ctx .equivalent_sites_data )
682
719
result = get_spectra_by_element (
683
720
elements_list ,
684
721
equivalent_sites_data ,
0 commit comments