Skip to content

Commit 66f3328

Browse files
Bvr4pre-commit-ci[bot]peterdudfield
authored
Change Pydantic from_orm to model_validate (#284) (#285)
* change pydantic from_orm to model_validate in tests, except from those related to Forecast model (#284) * change pydantic from_orm to model_validate in Forecast and ForecastValue models (#284) Updated from_orm class method in Forecast and ForecastValue models, so they call model_validate instead of the deprecated from_orm method. Created a class method model_validate in Forecast and ForecastValue models, so they work as from_orm class methods worked. * Change from_orm call to model_validate in nationay.py Updated Forecast and ForecastValue models from_orm call to model_validate call. * Change from_orm call to model_validate in convert.py Updated Forecast and ForecastValue models from_orm call to model_validate call. * Add model_validate_latest method in Forecast class Created model_validate_latest method that does the same thing as from_orm_latest. So it is consistent to the syntax of pydantic V2 * change pydantic from_orm to model_validate in tests related to Forecast and ForecastValue models * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade dockerfile to python 3.10 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Peter Dudfield <[email protected]>
1 parent 9f0cf8a commit 66f3328

File tree

10 files changed

+120
-45
lines changed

10 files changed

+120
-45
lines changed

nowcasting_datamodel/models/convert.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def convert_list_forecast_value_seven_days_sql_to_list_forecast(
3737
for forecast_value_sql in forecast_values_sql:
3838
gsp_id = forecast_value_sql.forecast.location.gsp_id
3939

40-
forecast_value: ForecastValue = ForecastValue.from_orm(forecast_value_sql)
40+
forecast_value: ForecastValue = ForecastValue.model_validate(
41+
forecast_value_sql, from_attributes=True
42+
)
4143

4244
if gsp_id in forecasts_by_gsp.keys():
4345
forecasts_by_gsp[gsp_id].forecast_values.append(forecast_value)
@@ -50,7 +52,7 @@ def convert_list_forecast_value_seven_days_sql_to_list_forecast(
5052
forecast_values=[forecast_value_sql],
5153
historic=forecast_value_sql.forecast.historic,
5254
)
53-
forecast = Forecast.from_orm(forecast)
55+
forecast = Forecast.model_validate(forecast, from_attributes=True)
5456
forecasts_by_gsp[gsp_id] = forecast
5557

5658
forecasts = [forecast for forecast in forecasts_by_gsp.values()]

nowcasting_datamodel/models/forecast.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,28 @@ def to_orm(self) -> ForecastValueSQL:
354354
@classmethod
355355
def from_orm(cls, obj: ForecastValueSQL):
356356
"""Make sure _adjust_mw is transfered also"""
357-
m = super().from_orm(obj=obj)
357+
m = super().model_validate(obj=obj, from_attributes=True)
358+
359+
# this is because from orm doesnt copy over '_' variables.
360+
# But we don't want to expose this in the API
361+
default_value = 0.0
362+
if hasattr(obj, "adjust_mw"):
363+
adjust_mw = obj.adjust_mw
364+
if not adjust_mw or np.isnan(adjust_mw):
365+
adjust_mw = default_value
366+
m._adjust_mw = adjust_mw
367+
else:
368+
m._adjust_mw = default_value
369+
370+
if hasattr(obj, "properties"):
371+
m._properties = obj.properties
372+
373+
return m
374+
375+
@classmethod
376+
def model_validate(cls, obj: ForecastValueSQL, from_attributes: bool | None = None):
377+
"""Make sure _adjust_mw is transfered also"""
378+
m = super().model_validate(obj=obj, from_attributes=from_attributes)
358379

359380
# this is because from orm doesnt copy over '_' variables.
360381
# But we don't want to expose this in the API
@@ -497,20 +518,36 @@ def to_orm(self) -> ForecastSQL:
497518

498519
@classmethod
499520
def from_orm(cls, forecast_sql: ForecastSQL):
500-
"""Method to make Forecast object from ForecastSQL,
521+
"""Method to make Forecast object from ForecastSQL"""
522+
# do normal transform
523+
return Forecast(
524+
forecast_creation_time=forecast_sql.forecast_creation_time,
525+
location=Location.model_validate(forecast_sql.location, from_attributes=True),
526+
input_data_last_updated=InputDataLastUpdated.model_validate(
527+
forecast_sql.input_data_last_updated, from_attributes=True
528+
),
529+
forecast_values=[
530+
ForecastValue.model_validate(forecast_value, from_attributes=True)
531+
for forecast_value in forecast_sql.forecast_values
532+
],
533+
historic=forecast_sql.historic,
534+
model=MLModel.model_validate(forecast_sql.model),
535+
)
501536

502-
but move 'forecast_values_latest' to 'forecast_values'
503-
This is useful as we want the API to still present a Forecast object.
504-
"""
537+
@classmethod
538+
def model_validate(cls, forecast_sql: ForecastSQL, from_attributes: bool | None = None):
539+
"""Method to make Forecast object from ForecastSQL"""
505540
# do normal transform
506541
return Forecast(
507542
forecast_creation_time=forecast_sql.forecast_creation_time,
508-
location=Location.from_orm(forecast_sql.location),
509-
input_data_last_updated=InputDataLastUpdated.from_orm(
510-
forecast_sql.input_data_last_updated
543+
location=Location.model_validate(
544+
forecast_sql.location, from_attributes=from_attributes
545+
),
546+
input_data_last_updated=InputDataLastUpdated.model_validate(
547+
forecast_sql.input_data_last_updated, from_attributes=from_attributes
511548
),
512549
forecast_values=[
513-
ForecastValue.from_orm(forecast_value)
550+
ForecastValue.model_validate(forecast_value, from_attributes=from_attributes)
514551
for forecast_value in forecast_sql.forecast_values
515552
],
516553
historic=forecast_sql.historic,
@@ -525,11 +562,29 @@ def from_orm_latest(cls, forecast_sql: ForecastSQL):
525562
This is useful as we want the API to still present a Forecast object.
526563
"""
527564
# do normal transform
528-
forecast = cls.from_orm(forecast_sql)
565+
forecast = cls.model_validate(forecast_sql, from_attributes=True)
566+
567+
# move 'forecast_values_latest' to 'forecast_values'
568+
forecast.forecast_values = [
569+
ForecastValue.model_validate(forecast_value, from_attributes=True)
570+
for forecast_value in forecast_sql.forecast_values_latest
571+
]
572+
573+
return forecast
574+
575+
@classmethod
576+
def model_validate_latest(cls, forecast_sql: ForecastSQL, from_attributes: bool | None = None):
577+
"""Method to make Forecast object from ForecastSQL,
578+
579+
but move 'forecast_values_latest' to 'forecast_values'
580+
This is useful as we want the API to still present a Forecast object.
581+
"""
582+
# do normal transform
583+
forecast = cls.model_validate(forecast_sql, from_attributes=from_attributes)
529584

530585
# move 'forecast_values_latest' to 'forecast_values'
531586
forecast.forecast_values = [
532-
ForecastValue.from_orm(forecast_value)
587+
ForecastValue.model_validate(forecast_value, from_attributes=from_attributes)
533588
for forecast_value in forecast_sql.forecast_values_latest
534589
]
535590

nowcasting_datamodel/national.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def make_national_forecast(
4545
gsp_id = forecast.location.gsp_id
4646

4747
one_gsp = pd.DataFrame(
48-
[ForecastValue.from_orm(value).dict() for value in forecast.forecast_values]
48+
[
49+
ForecastValue.model_validate(value, from_attributes=True).model_dump()
50+
for value in forecast.forecast_values
51+
]
4952
)
5053
adjusts_mw = [f.adjust_mw for f in forecast.forecast_values]
5154
one_gsp["gps_id"] = gsp_id
@@ -98,6 +101,6 @@ def make_national_forecast(
98101
)
99102

100103
# validate
101-
_ = Forecast.from_orm(national_forecast)
104+
_ = Forecast.model_validate(national_forecast, from_attributes=True)
102105

103106
return national_forecast

nowcasting_datamodel/save/adjust.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ def reduce_metric_values_to_correct_forecast_horizon(
189189
f"Reducing metric values to correct forecast horizon {datetime_now=} {hours_ahead=}"
190190
)
191191

192-
latest_me_df = pd.DataFrame([MetricValue.from_orm(m).dict() for m in latest_me])
192+
latest_me_df = pd.DataFrame(
193+
[MetricValue.model_validate(m, from_attributes=True).model_dump() for m in latest_me]
194+
)
193195
if len(latest_me_df) == 0:
194196
# no latest ME values, so just making an empty dataframe
195197
latest_me_df = pd.DataFrame(columns=["forecast_horizon_minutes", "time_of_day", "value"])

tests/models/test_models.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,51 +19,51 @@
1919
def test_adjust_forecasts(forecasts):
2020
forecasts[0].forecast_values[0].expected_power_generation_megawatts = 10.0
2121
forecasts[0].forecast_values[0].adjust_mw = 1.23
22-
forecasts = [Forecast.from_orm(f) for f in forecasts]
22+
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]
2323

2424
assert forecasts[0].forecast_values[0]._adjust_mw == 1.23
2525

2626
forecasts[0].adjust(limit=1.22)
2727
assert forecasts[0].forecast_values[0].expected_power_generation_megawatts == 8.78
28-
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].dict()
29-
assert "_adjust_mw" not in forecasts[0].forecast_values[0].dict()
28+
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].model_dump()
29+
assert "_adjust_mw" not in forecasts[0].forecast_values[0].model_dump()
3030

3131

3232
def test_adjust_forecast_neg(forecasts):
3333
forecasts[0].forecast_values[0].expected_power_generation_megawatts = 10.0
3434
forecasts[0].forecast_values[0].adjust_mw = -1.23
35-
forecasts = [Forecast.from_orm(f) for f in forecasts]
35+
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]
3636

3737
forecasts[0].adjust(limit=1.22)
3838
assert forecasts[0].forecast_values[0].expected_power_generation_megawatts == 11.22
39-
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].dict()
40-
assert "_adjust_mw" not in forecasts[0].forecast_values[0].dict()
39+
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].model_dump()
40+
assert "_adjust_mw" not in forecasts[0].forecast_values[0].model_dump()
4141

4242

4343
def test_adjust_forecast_below_zero(forecasts):
4444
v = forecasts[0].forecast_values[0].expected_power_generation_megawatts
4545
forecasts[0].forecast_values[0].adjust_mw = v + 100
46-
forecasts = [Forecast.from_orm(f) for f in forecasts]
46+
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]
4747
forecasts[0].forecast_values[0]._properties = {"10": v - 100}
4848

4949
forecasts[0].adjust(limit=v * 3)
5050

5151
assert forecasts[0].forecast_values[0].expected_power_generation_megawatts == 0.0
5252
assert forecasts[0].forecast_values[0]._properties["10"] == 0.0
53-
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].dict()
54-
assert "_adjust_mw" not in forecasts[0].forecast_values[0].dict()
53+
assert "expected_power_generation_megawatts" in forecasts[0].forecast_values[0].model_dump()
54+
assert "_adjust_mw" not in forecasts[0].forecast_values[0].model_dump()
5555

5656

5757
def test_adjust_many_forecasts(forecasts):
5858
forecasts[0].forecast_values[0].adjust_mw = 1.23
59-
forecasts = [Forecast.from_orm(f) for f in forecasts]
59+
forecasts = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]
6060
m = ManyForecasts(forecasts=forecasts)
6161
m.adjust()
6262

6363

6464
def test_normalize_forecasts(forecasts):
6565
v = forecasts[0].forecast_values[0].expected_power_generation_megawatts
66-
forecasts_all = [Forecast.from_orm(f) for f in forecasts]
66+
forecasts_all = [Forecast.model_validate(f, from_attributes=True) for f in forecasts]
6767

6868
forecasts_all[0].normalize()
6969
assert (
@@ -76,7 +76,7 @@ def test_normalize_forecasts(forecasts):
7676

7777

7878
def test_normalize_forecasts_no_installed_capacity(forecasts):
79-
forecast = Forecast.from_orm(forecasts[0])
79+
forecast = Forecast.model_validate(forecasts[0], from_attributes=True)
8080
forecast.location.installed_capacity_mw = None
8181

8282
v = forecast.forecast_values[0].expected_power_generation_megawatts
@@ -98,7 +98,7 @@ def test_status_validation():
9898
def test_status_orm():
9999
status = Status(message="testing", status="warning")
100100
ormed_status = status.to_orm()
101-
status_orm = Status.from_orm(ormed_status)
101+
status_orm = Status.model_validate(ormed_status, from_attributes=True)
102102

103103
assert status_orm.message == status.message
104104
assert status_orm.status == status.status
@@ -123,6 +123,12 @@ def test_forecast_latest_to_pydantic(forecast_sql):
123123
forecast = Forecast.from_orm_latest(forecast_sql=forecast_sql)
124124
assert forecast.forecast_values[0] == ForecastValue.from_orm(f1)
125125

126+
forecast = Forecast.model_validate(forecast_sql, from_attributes=True)
127+
assert forecast.forecast_values[0] != ForecastValue.model_validate(f1, from_attributes=True)
128+
129+
forecast = Forecast.model_validate_latest(forecast_sql=forecast_sql, from_attributes=True)
130+
assert forecast.forecast_values[0] == ForecastValue.model_validate(f1, from_attributes=True)
131+
126132

127133
def test_forecast_value_from_orm(forecast_sql):
128134
forecast_sql = forecast_sql[0]
@@ -131,7 +137,7 @@ def test_forecast_value_from_orm(forecast_sql):
131137
target_time=datetime(2023, 1, 1, 0, 30), expected_power_generation_megawatts=1
132138
)
133139

134-
actual = ForecastValue.from_orm(f)
140+
actual = ForecastValue.model_validate(f, from_attributes=True)
135141
expected = ForecastValue(
136142
target_time=datetime(2023, 1, 1, 0, 30, tzinfo=timezone.utc),
137143
expected_power_generation_megawatts=1.0,
@@ -148,7 +154,7 @@ def test_forecast_value_from_orm_from_adjust_mw_nan(forecast_sql, null_value):
148154
)
149155
f.adjust_mw = null_value
150156

151-
actual = ForecastValue.from_orm(f)
157+
actual = ForecastValue.model_validate(f, from_attributes=True)
152158
expected = ForecastValue(
153159
target_time=datetime(2023, 1, 1, 0, 30, tzinfo=timezone.utc),
154160
expected_power_generation_megawatts=1.0,

tests/read/test_read.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def test_get_forecast(db_session, forecasts):
7575
assert forecast_read.forecast_values[0] == forecasts[-1].forecast_values[0]
7676

7777
_ = Forecast.from_orm(forecast_read)
78+
_ = Forecast.model_validate(forecast_read, from_attributes=True)
7879

7980

8081
def test_read_gsp_id(db_session, forecasts):
@@ -125,6 +126,7 @@ def test_get_forecast_values_gsp_id(db_session, forecasts):
125126
)
126127

127128
_ = ForecastValue.from_orm(forecast_values_read[0])
129+
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)
128130

129131
assert len(forecast_values_read) == N_FAKE_FORECASTS
130132

@@ -152,7 +154,7 @@ def test_get_forecast_values_latest_gsp_id(db_session):
152154
forecast_values_read = get_forecast_values_latest(
153155
session=db_session, gsp_id=f1[0].location.gsp_id
154156
)
155-
_ = ForecastValue.from_orm(forecast_values_read[0])
157+
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)
156158

157159
assert len(forecast_values_read) == 2
158160
assert forecast_values_read[0].gsp_id == f1[0].location.gsp_id
@@ -220,7 +222,7 @@ def test_get_forecast_values_gsp_id_latest(db_session):
220222
start_datetime=datetime(2024, 1, 2, tzinfo=timezone.utc),
221223
)
222224

223-
_ = ForecastValue.from_orm(forecast_values_read[0])
225+
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)
224226

225227
assert len(forecast_values_read) == 16 # only getting forecast ahead
226228

@@ -244,7 +246,7 @@ def test_get_forecast_values_start_and_creation(db_session):
244246
created_utc_limit=datetime(2024, 1, 1, tzinfo=timezone.utc),
245247
)
246248

247-
_ = ForecastValue.from_orm(forecast_values_read[0])
249+
_ = ForecastValue.model_validate(forecast_values_read[0], from_attributes=True)
248250

249251
assert len(forecast_values_read) == 76 # only getting forecast ahead
250252

@@ -402,15 +404,17 @@ def test_get_national_latest_forecast(db_session):
402404

403405

404406
def test_get_pv_system(db_session_pv):
405-
pv_system = PVSystem.from_orm(make_fake_pv_system())
407+
pv_system = PVSystem.model_validate(make_fake_pv_system(), from_attributes=True)
406408
save_pv_system(session=db_session_pv, pv_system=pv_system)
407409

408410
pv_system_get = get_pv_system(
409411
session=db_session_pv, provider=pv_system.provider, pv_system_id=pv_system.pv_system_id
410412
)
411413
# this get defaulted to True when adding to the database
412414
pv_system.correct_data = True
413-
assert PVSystem.from_orm(pv_system) == PVSystem.from_orm(pv_system_get)
415+
assert PVSystem.model_validate(pv_system, from_attributes=True) == PVSystem.model_validate(
416+
pv_system_get, from_attributes=True
417+
)
414418

415419

416420
def test_get_latest_input_data_last_updated_multiple_entries(db_session):

tests/read/test_read_gsp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,10 @@ def test_get_gsp_yield_by_location(db_session):
203203
assert locations_with_gsp_yields[0].gsp_id == 1
204204
assert len(locations_with_gsp_yields[0].gsp_yields) == 2
205205

206-
locations = [LocationWithGSPYields.from_orm(location) for location in locations_with_gsp_yields]
206+
locations = [
207+
LocationWithGSPYields.model_validate(location, from_attributes=True)
208+
for location in locations_with_gsp_yields
209+
]
207210
assert len(locations[0].gsp_yields) == 2
208211
assert locations_with_gsp_yields[0].gsp_yields[0].datetime_utc.tzinfo == timezone.utc
209212

tests/read/test_read_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ def test_get_model(db_session):
5353
assert model_read_1.name == model_read_2.name
5454
assert model_read_1.version == model_read_2.version
5555

56-
_ = MLModel.from_orm(model_read_2)
56+
_ = MLModel.model_validate(model_read_2, from_attributes=True)

tests/test_fake.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,29 @@ def test_make_fake_intensity():
3636

3737
def test_make_fake_location():
3838
location_sql: LocationSQL = make_fake_location(1)
39-
location = Location.from_orm(location_sql)
39+
location = Location.model_validate(location_sql, from_attributes=True)
4040
_ = Location.to_orm(location)
4141

4242

4343
def test_make_fake_input_data_last_updated():
4444
input_sql: InputDataLastUpdatedSQL = make_fake_input_data_last_updated()
45-
input = InputDataLastUpdated.from_orm(input_sql)
45+
input = InputDataLastUpdated.model_validate(input_sql, from_attributes=True)
4646
_ = InputDataLastUpdated.to_orm(input)
4747

4848

4949
def test_make_fake_forecast_value():
5050
target = datetime(2023, 1, 1, tzinfo=timezone.utc)
5151

5252
forecast_value_sql: ForecastValueSQL = make_fake_forecast_value(target_time=target)
53-
forecast_value = ForecastValue.from_orm(forecast_value_sql)
53+
forecast_value = ForecastValue.model_validate(forecast_value_sql, from_attributes=True)
5454
_ = ForecastValue.to_orm(forecast_value)
5555

5656

5757
def test_make_fake_forecast(db_session):
5858
forecast_sql: ForecastSQL = make_fake_forecast(gsp_id=1, session=db_session)
59-
forecast = Forecast.from_orm(forecast_sql)
59+
forecast = Forecast.model_validate(forecast_sql, from_attributes=True)
6060
forecast_sql = Forecast.to_orm(forecast)
61-
_ = Forecast.from_orm(forecast_sql)
61+
_ = Forecast.model_validate(forecast_sql, from_attributes=True)
6262

6363
from sqlalchemy import text
6464

@@ -78,7 +78,7 @@ def test_make_fake_forecasts(db_session):
7878

7979
def test_make_national_fake_forecast(db_session):
8080
forecast_sql: ForecastSQL = make_fake_national_forecast(session=db_session)
81-
forecast = Forecast.from_orm(forecast_sql)
81+
forecast = Forecast.model_validate(forecast_sql, from_attributes=True)
8282
_ = Forecast.to_orm(forecast)
8383

8484

0 commit comments

Comments
 (0)