diff --git a/github/avatar_manager.py b/github/avatar_manager.py
index 5ef699e..cc9b418 100644
--- a/github/avatar_manager.py
+++ b/github/avatar_manager.py
@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import asyncio
+import time
from sqlalchemy import Column, MetaData, Table, Text
from sqlalchemy.engine.base import Engine
@@ -15,6 +16,8 @@
class AvatarManager:
bot: "GitHubBot"
_avatars: dict[str, ContentURI]
+ _etag: dict[str, Optional[str]]
+ _fetched_at: dict[str, int]
_db: DBManager
_lock: asyncio.Lock
@@ -23,26 +26,50 @@ def __init__(self, bot: "GitHubBot") -> None:
self._db = bot.db
self._lock = asyncio.Lock()
self._avatars = {}
+ self._etag = {}
+ self._fetched_at = {}
async def load_db(self) -> None:
- self._avatars = {
- avatar.url: ContentURI(avatar.mxc) for avatar in await self._db.get_avatars()
- }
+ rows = await self._db.get_avatars()
+ self._avatars = {avatar.url: ContentURI(avatar.mxc) for avatar in rows}
+ self._etag = {avatar.url: avatar.etag for avatar in rows}
+ self._fetched_at = {avatar.url: int(avatar.fetched_at or 0) for avatar in rows}
async def get_mxc(self, url: str) -> ContentURI:
- try:
+ now = int(time.time())
+ # 5 min TTL
+ if url in self._avatars and (now - self._fetched_at.get(url, 0)) < 300:
return self._avatars[url]
- except KeyError:
- pass
- async with self.bot.http.get(url) as resp:
+
+ headers: dict[str, str] = {}
+ etag = self._etag.get(url)
+ if etag:
+ headers["If-None-Match"] = etag
+
+ async with self.bot.http.get(url, headers=headers) as resp:
+ if resp.status == 304 and url in self._avatars:
+ # Unchanged: bump fetched_at and persist
+ self._fetched_at[url] = now
+ await self._db.put_avatar(
+ url,
+ self._avatars[url],
+ etag=self._etag.get(url),
+ fetched_at=now,
+ )
+ return self._avatars[url]
+
resp.raise_for_status()
data = await resp.read()
+ new_etag = resp.headers.get("ETag")
+
async with self._lock:
- try:
+ # Race guard with same TTL inside the lock
+ if url in self._avatars and (now - self._fetched_at.get(url, 0)) < 300:
return self._avatars[url]
- except KeyError:
- pass
+
mxc = await self.bot.client.upload_media(data)
self._avatars[url] = mxc
- await self._db.put_avatar(url, mxc)
- return mxc
+ self._etag[url] = new_etag
+ self._fetched_at[url] = now
+ await self._db.put_avatar(url, mxc, etag=new_etag, fetched_at=now)
+ return mxc
diff --git a/github/db.py b/github/db.py
index 40345c8..b1d72f5 100644
--- a/github/db.py
+++ b/github/db.py
@@ -14,6 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+
from typing import Optional
import hashlib
import hmac
@@ -47,6 +48,8 @@ def from_row(cls, row: Record | None) -> Optional["Client"]:
class Avatar:
url: str
mxc: ContentURI
+ etag: Optional[str] = None
+ fetched_at: int = 0
@classmethod
def from_row(cls, row: Record | None) -> Optional["Avatar"]:
@@ -54,9 +57,13 @@ def from_row(cls, row: Record | None) -> Optional["Avatar"]:
return None
url = row["url"]
mxc = row["mxc"]
+ etag = row.get("etag")
+ fetched_at = int(row.get("fetched_at") or 0)
return cls(
url=url,
mxc=mxc,
+ etag=etag,
+ fetched_at=fetched_at,
)
@@ -145,17 +152,29 @@ async def delete_client(self, user_id: UserID) -> None:
)
async def get_avatars(self) -> list[Avatar]:
- rows = await self.db.fetch("SELECT url, mxc FROM avatar")
+ rows = await self.db.fetch("SELECT url, mxc, etag, fetched_at FROM avatar")
return [Avatar.from_row(row) for row in rows]
- async def put_avatar(self, url: str, mxc: ContentURI) -> None:
+ async def put_avatar(
+ self,
+ url: str,
+ mxc: ContentURI,
+ *,
+ etag: Optional[str] = None,
+ fetched_at: Optional[int] = None,
+ ) -> None:
await self.db.execute(
"""
- INSERT INTO avatar (url, mxc) VALUES ($1, $2)
- ON CONFLICT (url) DO NOTHING
+ INSERT INTO avatar (url, mxc, etag, fetched_at) VALUES ($1, $2, $3, $4)
+ ON CONFLICT (url) DO UPDATE SET
+ mxc = excluded.mxc,
+ etag = excluded.etag,
+ fetched_at = excluded.fetched_at
""",
url,
mxc,
+ etag,
+ fetched_at,
)
async def get_webhook_by_id(self, id: uuid.UUID) -> WebhookInfo | None:
diff --git a/github/migrations.py b/github/migrations.py
index 50ad2c6..055ccc6 100644
--- a/github/migrations.py
+++ b/github/migrations.py
@@ -13,12 +13,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+
from mautrix.util.async_db import Connection, Scheme, UpgradeTable
upgrade_table = UpgradeTable()
-@upgrade_table.register(description="Latest revision", upgrades_to=1)
+@upgrade_table.register(description="Initial schema", upgrades_to=1)
async def upgrade_latest(conn: Connection, scheme: Scheme) -> None:
needs_migration = False
if await conn.table_exists("webhook"):
@@ -26,8 +27,9 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None:
await conn.execute("ALTER TABLE webhook RENAME TO webhook_old;")
await conn.execute("ALTER TABLE client RENAME TO client_old;")
await conn.execute("ALTER TABLE matrix_message RENAME TO matrix_message_old;")
+
await conn.execute(
- f"""CREATE TABLE client (
+ """CREATE TABLE client (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
PRIMARY KEY (user_id)
@@ -60,6 +62,7 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None:
PRIMARY KEY (url)
)"""
)
+
if needs_migration:
await migrate_legacy_to_v1(conn)
@@ -67,6 +70,13 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None:
async def migrate_legacy_to_v1(conn: Connection) -> None:
await conn.execute("INSERT INTO client (user_id, token) SELECT user_id, token FROM client_old")
await conn.execute(
- "INSERT INTO matrix_message (message_id, room_id, event_id) SELECT message_id, room_id, event_id FROM matrix_message_old"
+ "INSERT INTO matrix_message (message_id, room_id, event_id) "
+ "SELECT message_id, room_id, event_id FROM matrix_message_old"
)
await conn.execute("CREATE TABLE needs_post_migration(noop INTEGER PRIMARY KEY)")
+
+
+@upgrade_table.register(description="Add etag and fetched_at to avatar table", upgrades_to=2)
+async def upgrade_v2(conn: Connection) -> None:
+ await conn.execute("ALTER TABLE avatar ADD COLUMN etag TEXT")
+ await conn.execute("ALTER TABLE avatar ADD COLUMN fetched_at INTEGER")