Skip to content

Commit 37162ab

Browse files
esoteric-ephemeratsmathis
authored andcommitted
ensure that TaskDoc.{input,orig_inputs,calcs_reversed.*.input} share same base class
1 parent b269fb6 commit 37162ab

File tree

1 file changed

+22
-36
lines changed

1 file changed

+22
-36
lines changed

emmet-core/emmet/core/tasks.py

+22-36
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
task_type,
2222
)
2323
from emmet.core.vasp.calculation import (
24-
CalculationBaseModel,
24+
CalculationInput,
2525
Calculation,
2626
PotcarSpec,
2727
RunStatistics,
@@ -61,22 +61,13 @@ class Potcar(BaseModel):
6161
)
6262

6363

64-
class OrigInputs(CalculationBaseModel):
65-
incar: Optional[Union[Incar, Dict]] = Field(
66-
None,
67-
description="Pymatgen object representing the INCAR file.",
68-
)
64+
class OrigInputs(CalculationInput):
6965

7066
poscar: Optional[Poscar] = Field(
7167
None,
7268
description="Pymatgen object representing the POSCAR file.",
7369
)
7470

75-
kpoints: Optional[Kpoints] = Field(
76-
None,
77-
description="Pymatgen object representing the KPOINTS file.",
78-
)
79-
8071
potcar: Optional[Union[Potcar, VaspPotcar, List[Any]]] = Field(
8172
None,
8273
description="Pymatgen object representing the POTCAR file.",
@@ -182,33 +173,26 @@ def from_vasp_calc_doc(
182173
)
183174

184175

185-
class InputDoc(BaseModel):
186-
structure: Optional[Structure] = Field(
187-
None,
188-
title="Input Structure",
189-
description="Output Structure from the VASP calculation.",
190-
)
176+
class InputDoc(CalculationInput):
177+
"""Light wrapper around `CalculationInput` with a few extra fields.
178+
179+
pseudo_potentials (Potcar) : summary of the POTCARs used in the calculation
180+
xc_override (str) : the exchange-correlation functional used if not
181+
the one specified by POTCAR
182+
is_lasph (bool) : how the calculation set LASPH (aspherical corrections)
183+
magnetic_moments (list of floats) : on-site magnetic moments
184+
"""
191185

192-
parameters: Optional[Dict] = Field(
193-
None,
194-
description="Parameters from vasprun for the last calculation in the series",
195-
)
196186
pseudo_potentials: Optional[Potcar] = Field(
197187
None, description="Summary of the pseudo-potentials used in this calculation"
198188
)
199-
potcar_spec: Optional[List[PotcarSpec]] = Field(
200-
None, description="Title and hash of POTCAR files used in the calculation"
201-
)
189+
202190
xc_override: Optional[str] = Field(
203191
None, description="Exchange-correlation functional used if not the default"
204192
)
205193
is_lasph: Optional[bool] = Field(
206194
None, description="Whether the calculation was run with aspherical corrections"
207195
)
208-
is_hubbard: bool = Field(
209-
default=False, description="Is this a Hubbard +U calculation"
210-
)
211-
hubbards: Optional[dict] = Field(None, description="The hubbard parameters used")
212196
magnetic_moments: Optional[List[float]] = Field(
213197
None, description="Magnetic moments for each atom"
214198
)
@@ -238,22 +222,18 @@ def from_vasp_calc_doc(cls, calc_doc: Calculation) -> "InputDoc":
238222
InputDoc
239223
A summary of the input structure and parameters.
240224
"""
241-
xc = calc_doc.input.incar.get("GGA")
225+
xc = calc_doc.input.incar.get("GGA") or calc_doc.input.incar.get("METAGGA")
242226
if xc:
243227
xc = xc.upper()
244228

245229
pot_type, func = calc_doc.input.potcar_type[0].split("_")
246230
func = "lda" if len(pot_type) == 1 else "_".join(func)
247231
pps = Potcar(pot_type=pot_type, functional=func, symbols=calc_doc.input.potcar)
248232
return cls(
249-
structure=calc_doc.input.structure,
250-
parameters=calc_doc.input.parameters,
233+
**calc_doc.input.model_dump(),
251234
pseudo_potentials=pps,
252-
potcar_spec=calc_doc.input.potcar_spec,
253235
xc_override=xc,
254236
is_lasph=calc_doc.input.parameters.get("LASPH", False),
255-
is_hubbard=calc_doc.input.is_hubbard,
256-
hubbards=calc_doc.input.hubbards,
257237
magnetic_moments=calc_doc.input.parameters.get("MAGMOM"),
258238
)
259239

@@ -468,9 +448,15 @@ def model_post_init(self, __context: Any) -> None:
468448
# Always refresh task_type, calc_type, run_type
469449
# See, e.g. https://github.com/materialsproject/emmet/issues/960
470450
# where run_type's were set incorrectly in older versions of TaskDoc
471-
self.task_type = task_type(self.orig_inputs)
451+
452+
# To determine task and run type, we search for input sets in this order
453+
# of precedence: calcs_reversed, inputs, orig_inputs
454+
for inp_set in [self.calcs_reversed[0].input, self.input, self.orig_inputs]:
455+
if inp_set is not None:
456+
break
457+
self.task_type = task_type(inp_set)
472458
self.run_type = self._get_run_type(self.calcs_reversed)
473-
self.calc_type = self._get_calc_type(self.calcs_reversed, self.orig_inputs)
459+
self.calc_type = self._get_calc_type(self.calcs_reversed, inp_set)
474460

475461
# TODO: remove after imposing TaskDoc schema on older tasks in collection
476462
if self.structure is None:

0 commit comments

Comments
 (0)