|
17 | 17 | import os |
18 | 18 | import time |
19 | 19 | from contextlib import asynccontextmanager |
20 | | -from pathlib import Path |
21 | 20 | from typing import Final |
22 | 21 |
|
23 | | -import apsw |
24 | | -import apsw.bestpractice |
| 22 | +import mariadb |
25 | 23 | import uvicorn |
26 | 24 | from fastapi import FastAPI |
27 | 25 | from fastapi.middleware.cors import CORSMiddleware |
|
80 | 78 | ] |
81 | 79 |
|
82 | 80 |
|
83 | | -def _enable_best_practice(connection: apsw.Connection): |
84 | | - """Enable aspw best practice.""" |
85 | | - apsw.bestpractice.connection_wal(connection) |
86 | | - apsw.bestpractice.library_logging() |
| 81 | +def _get_database_connection() -> mariadb.Connection: |
| 82 | + """Get a MriaDB database connection.""" |
| 83 | + connection = mariadb.connect( |
| 84 | + user=os.environ["DB_USER"], |
| 85 | + password=os.environ["DB_PASS"], |
| 86 | + host="127.0.0.1", |
| 87 | + port=3306, |
| 88 | + database=os.environ["DB_DATABASE"], |
| 89 | + autocommit=True, |
| 90 | + ) |
| 91 | + return connection |
87 | 92 |
|
88 | 93 |
|
89 | 94 | @asynccontextmanager |
90 | 95 | async def lifespan(app: FastAPI): |
91 | 96 | """Load the database connection for the life of the app.s""" |
92 | | - db_path = Path(os.environ["DATABASE_PATH"]) |
93 | | - logger.info("validator database: %s", db_path) |
94 | | - app.state.connection = apsw.Connection( |
95 | | - str(db_path), flags=apsw.SQLITE_OPEN_READONLY |
96 | | - ) |
97 | | - _enable_best_practice(app.state.connection) |
| 97 | + app.state.connection = _get_database_connection() |
98 | 98 | app.state.kupo_url = os.environ["KUPO_URL"] |
99 | 99 | app.state.kupo_port = os.environ["KUPO_PORT"] |
100 | 100 | yield |
@@ -141,34 +141,36 @@ def redirect_root_to_docs(): |
141 | 141 | @app.get("/get_active_participants", tags=[TAG_STATISTICS]) |
142 | 142 | async def get_active_participants(): |
143 | 143 | """Return participants in the ITN database.""" |
| 144 | + cursor = app.state.connection.cursor() |
144 | 145 | try: |
145 | | - participants = app.state.connection.execute( |
146 | | - "select distinct address from data_points;" |
147 | | - ) |
148 | | - except apsw.SQLError as err: |
| 146 | + cursor.execute("select distinct address from data_points;") |
| 147 | + except mariadb.Error as err: |
149 | 148 | return {"error": f"{err}"} |
150 | | - data = [participant[0] for participant in participants] |
| 149 | + data = [participant[0] for participant in cursor] |
| 150 | + cursor.close() |
151 | 151 | return data |
152 | 152 |
|
153 | 153 |
|
154 | 154 | @app.get("/get_participants_counts_total", tags=[TAG_STATISTICS]) |
155 | 155 | async def get_participants_counts_total(): |
156 | 156 | """Return participants total counts.""" |
| 157 | + cursor = app.state.connection.cursor() |
157 | 158 | try: |
158 | | - participants_count_total = app.state.connection.execute( |
| 159 | + cursor.execute( |
159 | 160 | "select count(*) as count, address from data_points group by address order by count desc;" |
160 | 161 | ) |
161 | | - except apsw.SQLError as err: |
| 162 | + except mariadb.Error as err: |
162 | 163 | return {"error": f"{err}"} |
163 | | - return participants_count_total |
| 164 | + res = list(cursor) |
| 165 | + cursor.close() |
| 166 | + return res |
164 | 167 |
|
165 | 168 |
|
166 | 169 | @app.get("/get_participants_counts_day", tags=[TAG_STATISTICS]) |
167 | 170 | async def get_participants_counts_day( |
168 | 171 | date_start: str = "1970-01-01", date_end: str = "1970-01-03" |
169 | 172 | ): |
170 | 173 | """Return participants in ITN.""" |
171 | | - |
172 | 174 | report = reports.get_participants_counts_date_range(app, date_start, date_end) |
173 | 175 | return report |
174 | 176 |
|
@@ -231,28 +233,33 @@ async def get_itn_participants() -> str: |
231 | 233 | @app.get("/online_collectors", tags=[TAG_HTMX], response_class=HTMLResponse) |
232 | 234 | async def get_online_collectors() -> str: |
233 | 235 | """Return ITN aliases and collector counts.""" |
| 236 | + cursor = app.state.connection.cursor() |
234 | 237 | try: |
235 | | - participants_count = app.state.connection.execute( |
| 238 | + cursor.execute( |
236 | 239 | """SELECT address, COUNT(*) AS total_count, |
237 | | - SUM(CASE WHEN datetime(date_time) >= datetime('now', '-24 hours') |
| 240 | + SUM(CASE WHEN date_time >= (SELECT DATE_SUB(NOW(), INTERVAL 1 DAY)) |
238 | 241 | THEN 1 ELSE 0 END) AS count_24hr |
239 | 242 | FROM data_points |
240 | 243 | GROUP BY address ORDER BY total_count DESC; |
241 | 244 | """ |
242 | 245 | ) |
243 | | - except apsw.SQLError: |
| 246 | + except mariadb.Error: |
244 | 247 | return "zero collectors online" |
245 | 248 |
|
| 249 | + participants_count = list(cursor) |
| 250 | + |
246 | 251 | try: |
247 | | - feed_count = app.state.connection.execute( |
| 252 | + cursor.execute( |
248 | 253 | """SELECT distinct feed_id |
249 | 254 | from data_points |
250 | | - where datetime(date_time) >= datetime('now', '-48 hours'); |
| 255 | + where date_time >= (SELECT DATE_SUB(NOW(), INTERVAL 1 DAY)); |
251 | 256 | """ |
252 | 257 | ) |
253 | | - except apsw.SQLError: |
| 258 | + except mariadb.Error: |
254 | 259 | return "zero collectors online" |
255 | 260 |
|
| 261 | + feed_count = list(cursor) |
| 262 | + |
256 | 263 | no_feeds = len(list(feed_count)) |
257 | 264 |
|
258 | 265 | # FIXME: These can all be combined better, e.g. into a dataclass or |
@@ -308,13 +315,13 @@ async def get_locations_map_hx(): |
308 | 315 | @app.get("/count_active_participants", tags=[TAG_HTMX], response_class=HTMLResponse) |
309 | 316 | async def count_active_participants(): |
310 | 317 | """Count active participants.""" |
| 318 | + cursor = app.state.connection.cursor() |
311 | 319 | try: |
312 | | - participants = app.state.connection.execute( |
313 | | - "select count(distinct address) as count from data_points;" |
314 | | - ) |
315 | | - except apsw.SQLError as err: |
| 320 | + cursor.execute("select count(distinct address) as count from data_points;") |
| 321 | + except mariadb.Error as err: |
316 | 322 | return {"error": f"{err}"} |
317 | | - data = list(participants) |
| 323 | + data = list(cursor) |
| 324 | + cursor.close() |
318 | 325 | return f"{data[0][0]}" |
319 | 326 |
|
320 | 327 |
|
|
0 commit comments