diff --git a/github/avatar_manager.py b/github/avatar_manager.py index 5ef699e..cc9b418 100644 --- a/github/avatar_manager.py +++ b/github/avatar_manager.py @@ -1,5 +1,6 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import asyncio +import time from sqlalchemy import Column, MetaData, Table, Text from sqlalchemy.engine.base import Engine @@ -15,6 +16,8 @@ class AvatarManager: bot: "GitHubBot" _avatars: dict[str, ContentURI] + _etag: dict[str, Optional[str]] + _fetched_at: dict[str, int] _db: DBManager _lock: asyncio.Lock @@ -23,26 +26,50 @@ def __init__(self, bot: "GitHubBot") -> None: self._db = bot.db self._lock = asyncio.Lock() self._avatars = {} + self._etag = {} + self._fetched_at = {} async def load_db(self) -> None: - self._avatars = { - avatar.url: ContentURI(avatar.mxc) for avatar in await self._db.get_avatars() - } + rows = await self._db.get_avatars() + self._avatars = {avatar.url: ContentURI(avatar.mxc) for avatar in rows} + self._etag = {avatar.url: avatar.etag for avatar in rows} + self._fetched_at = {avatar.url: int(avatar.fetched_at or 0) for avatar in rows} async def get_mxc(self, url: str) -> ContentURI: - try: + now = int(time.time()) + # 5 min TTL + if url in self._avatars and (now - self._fetched_at.get(url, 0)) < 300: return self._avatars[url] - except KeyError: - pass - async with self.bot.http.get(url) as resp: + + headers: dict[str, str] = {} + etag = self._etag.get(url) + if etag: + headers["If-None-Match"] = etag + + async with self.bot.http.get(url, headers=headers) as resp: + if resp.status == 304 and url in self._avatars: + # Unchanged: bump fetched_at and persist + self._fetched_at[url] = now + await self._db.put_avatar( + url, + self._avatars[url], + etag=self._etag.get(url), + fetched_at=now, + ) + return self._avatars[url] + resp.raise_for_status() data = await resp.read() + new_etag = resp.headers.get("ETag") + async with self._lock: - try: + # Race guard with same TTL inside the lock + if url in self._avatars and (now - self._fetched_at.get(url, 0)) < 300: return self._avatars[url] - except KeyError: - pass + mxc = await self.bot.client.upload_media(data) self._avatars[url] = mxc - await self._db.put_avatar(url, mxc) - return mxc + self._etag[url] = new_etag + self._fetched_at[url] = now + await self._db.put_avatar(url, mxc, etag=new_etag, fetched_at=now) + return mxc diff --git a/github/db.py b/github/db.py index 40345c8..b1d72f5 100644 --- a/github/db.py +++ b/github/db.py @@ -14,6 +14,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . + from typing import Optional import hashlib import hmac @@ -47,6 +48,8 @@ def from_row(cls, row: Record | None) -> Optional["Client"]: class Avatar: url: str mxc: ContentURI + etag: Optional[str] = None + fetched_at: int = 0 @classmethod def from_row(cls, row: Record | None) -> Optional["Avatar"]: @@ -54,9 +57,13 @@ def from_row(cls, row: Record | None) -> Optional["Avatar"]: return None url = row["url"] mxc = row["mxc"] + etag = row.get("etag") + fetched_at = int(row.get("fetched_at") or 0) return cls( url=url, mxc=mxc, + etag=etag, + fetched_at=fetched_at, ) @@ -145,17 +152,29 @@ async def delete_client(self, user_id: UserID) -> None: ) async def get_avatars(self) -> list[Avatar]: - rows = await self.db.fetch("SELECT url, mxc FROM avatar") + rows = await self.db.fetch("SELECT url, mxc, etag, fetched_at FROM avatar") return [Avatar.from_row(row) for row in rows] - async def put_avatar(self, url: str, mxc: ContentURI) -> None: + async def put_avatar( + self, + url: str, + mxc: ContentURI, + *, + etag: Optional[str] = None, + fetched_at: Optional[int] = None, + ) -> None: await self.db.execute( """ - INSERT INTO avatar (url, mxc) VALUES ($1, $2) - ON CONFLICT (url) DO NOTHING + INSERT INTO avatar (url, mxc, etag, fetched_at) VALUES ($1, $2, $3, $4) + ON CONFLICT (url) DO UPDATE SET + mxc = excluded.mxc, + etag = excluded.etag, + fetched_at = excluded.fetched_at """, url, mxc, + etag, + fetched_at, ) async def get_webhook_by_id(self, id: uuid.UUID) -> WebhookInfo | None: diff --git a/github/migrations.py b/github/migrations.py index 50ad2c6..055ccc6 100644 --- a/github/migrations.py +++ b/github/migrations.py @@ -13,12 +13,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . + from mautrix.util.async_db import Connection, Scheme, UpgradeTable upgrade_table = UpgradeTable() -@upgrade_table.register(description="Latest revision", upgrades_to=1) +@upgrade_table.register(description="Initial schema", upgrades_to=1) async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: needs_migration = False if await conn.table_exists("webhook"): @@ -26,8 +27,9 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: await conn.execute("ALTER TABLE webhook RENAME TO webhook_old;") await conn.execute("ALTER TABLE client RENAME TO client_old;") await conn.execute("ALTER TABLE matrix_message RENAME TO matrix_message_old;") + await conn.execute( - f"""CREATE TABLE client ( + """CREATE TABLE client ( user_id TEXT NOT NULL, token TEXT NOT NULL, PRIMARY KEY (user_id) @@ -60,6 +62,7 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: PRIMARY KEY (url) )""" ) + if needs_migration: await migrate_legacy_to_v1(conn) @@ -67,6 +70,13 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: async def migrate_legacy_to_v1(conn: Connection) -> None: await conn.execute("INSERT INTO client (user_id, token) SELECT user_id, token FROM client_old") await conn.execute( - "INSERT INTO matrix_message (message_id, room_id, event_id) SELECT message_id, room_id, event_id FROM matrix_message_old" + "INSERT INTO matrix_message (message_id, room_id, event_id) " + "SELECT message_id, room_id, event_id FROM matrix_message_old" ) await conn.execute("CREATE TABLE needs_post_migration(noop INTEGER PRIMARY KEY)") + + +@upgrade_table.register(description="Add etag and fetched_at to avatar table", upgrades_to=2) +async def upgrade_v2(conn: Connection) -> None: + await conn.execute("ALTER TABLE avatar ADD COLUMN etag TEXT") + await conn.execute("ALTER TABLE avatar ADD COLUMN fetched_at INTEGER")