Skip to content

Commit 3e14890

Browse files
Merge pull request #14 from openclimatefix/add-read-all-locations
add method to read all lcations + tests
2 parents ef2fd40 + 617dcce commit 3e14890

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

nowcasting_datamodel/read/read.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,33 @@ def get_location(session: Session, gsp_id: int) -> LocationSQL:
175175
return location
176176

177177

178+
def get_all_location(session: Session, gsp_ids: List[int] = None) -> List[LocationSQL]:
179+
"""
180+
Get all location object from gsp id
181+
182+
:param session: database session
183+
:param gsp_ids: list of gsp id of the location
184+
185+
return: List of GSP locations
186+
187+
"""
188+
189+
# start main query
190+
query = session.query(LocationSQL)
191+
query = query.distinct(LocationSQL.gsp_id)
192+
193+
# filter on gsp_id
194+
if gsp_ids is not None:
195+
query = query.filter(LocationSQL.gsp_id.in_(gsp_ids))
196+
197+
query = query.order_by(LocationSQL.gsp_id)
198+
199+
# get all results
200+
locations = query.all()
201+
202+
return locations
203+
204+
178205
def get_model(session: Session, name: str, version: str) -> MLModelSQL:
179206
"""
180207
Get model object from name and version

tests/test_read.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from nowcasting_datamodel.models import Forecast, ForecastValue, LocationSQL, MLModel, PVSystem
99
from nowcasting_datamodel.read.read import (
1010
get_all_gsp_ids_latest_forecast,
11+
get_all_location,
1112
get_forecast_values,
1213
get_latest_forecast,
1314
get_latest_national_forecast,
@@ -19,6 +20,15 @@
1920
logger = logging.getLogger(__name__)
2021

2122

23+
def test_get_all_location(db_session):
24+
25+
db_session.add(LocationSQL(label="GSP_1", gsp_id=1))
26+
db_session.add(LocationSQL(label="GSP_2", gsp_id=2))
27+
28+
locations = get_all_location(session=db_session)
29+
assert len(locations) == 2
30+
31+
2232
def test_get_model(db_session):
2333

2434
model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9")

0 commit comments

Comments
 (0)