Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions github/avatar_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import asyncio
import time

from sqlalchemy import Column, MetaData, Table, Text
from sqlalchemy.engine.base import Engine
Expand All @@ -15,6 +16,8 @@
class AvatarManager:
bot: "GitHubBot"
_avatars: dict[str, ContentURI]
_etag: dict[str, Optional[str]]
_fetched_at: dict[str, int]
_db: DBManager
_lock: asyncio.Lock

Expand All @@ -23,26 +26,50 @@ def __init__(self, bot: "GitHubBot") -> None:
self._db = bot.db
self._lock = asyncio.Lock()
self._avatars = {}
self._etag = {}
self._fetched_at = {}

async def load_db(self) -> None:
self._avatars = {
avatar.url: ContentURI(avatar.mxc) for avatar in await self._db.get_avatars()
}
rows = await self._db.get_avatars()
self._avatars = {avatar.url: ContentURI(avatar.mxc) for avatar in rows}
self._etag = {avatar.url: avatar.etag for avatar in rows}
self._fetched_at = {avatar.url: int(avatar.fetched_at or 0) for avatar in rows}

async def get_mxc(self, url: str) -> ContentURI:
try:
now = int(time.time())
# 5 min TTL
if url in self._avatars and (now - self._fetched_at.get(url, 0)) < 300:
return self._avatars[url]
except KeyError:
pass
async with self.bot.http.get(url) as resp:

headers: dict[str, str] = {}
etag = self._etag.get(url)
if etag:
headers["If-None-Match"] = etag

async with self.bot.http.get(url, headers=headers) as resp:
if resp.status == 304 and url in self._avatars:
# Unchanged: bump fetched_at and persist
self._fetched_at[url] = now
await self._db.put_avatar(
url,
self._avatars[url],
etag=self._etag.get(url),
fetched_at=now,
)
return self._avatars[url]

resp.raise_for_status()
data = await resp.read()
new_etag = resp.headers.get("ETag")

async with self._lock:
try:
# Race guard with same TTL inside the lock
if url in self._avatars and (now - self._fetched_at.get(url, 0)) < 300:
return self._avatars[url]
except KeyError:
pass

mxc = await self.bot.client.upload_media(data)
self._avatars[url] = mxc
await self._db.put_avatar(url, mxc)
return mxc
self._etag[url] = new_etag
self._fetched_at[url] = now
await self._db.put_avatar(url, mxc, etag=new_etag, fetched_at=now)
return mxc
27 changes: 23 additions & 4 deletions github/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import Optional
import hashlib
import hmac
Expand Down Expand Up @@ -47,16 +48,22 @@ def from_row(cls, row: Record | None) -> Optional["Client"]:
class Avatar:
url: str
mxc: ContentURI
etag: Optional[str] = None
fetched_at: int = 0

@classmethod
def from_row(cls, row: Record | None) -> Optional["Avatar"]:
if not row:
return None
url = row["url"]
mxc = row["mxc"]
etag = row.get("etag")
fetched_at = int(row.get("fetched_at") or 0)
return cls(
url=url,
mxc=mxc,
etag=etag,
fetched_at=fetched_at,
)


Expand Down Expand Up @@ -145,17 +152,29 @@ async def delete_client(self, user_id: UserID) -> None:
)

async def get_avatars(self) -> list[Avatar]:
rows = await self.db.fetch("SELECT url, mxc FROM avatar")
rows = await self.db.fetch("SELECT url, mxc, etag, fetched_at FROM avatar")
return [Avatar.from_row(row) for row in rows]

async def put_avatar(self, url: str, mxc: ContentURI) -> None:
async def put_avatar(
self,
url: str,
mxc: ContentURI,
*,
etag: Optional[str] = None,
fetched_at: Optional[int] = None,
) -> None:
await self.db.execute(
"""
INSERT INTO avatar (url, mxc) VALUES ($1, $2)
ON CONFLICT (url) DO NOTHING
INSERT INTO avatar (url, mxc, etag, fetched_at) VALUES ($1, $2, $3, $4)
ON CONFLICT (url) DO UPDATE SET
mxc = excluded.mxc,
etag = excluded.etag,
fetched_at = excluded.fetched_at
""",
url,
mxc,
etag,
fetched_at,
)

async def get_webhook_by_id(self, id: uuid.UUID) -> WebhookInfo | None:
Expand Down
16 changes: 13 additions & 3 deletions github/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from mautrix.util.async_db import Connection, Scheme, UpgradeTable

upgrade_table = UpgradeTable()


@upgrade_table.register(description="Latest revision", upgrades_to=1)
@upgrade_table.register(description="Initial schema", upgrades_to=1)
async def upgrade_latest(conn: Connection, scheme: Scheme) -> None:
needs_migration = False
if await conn.table_exists("webhook"):
needs_migration = True
await conn.execute("ALTER TABLE webhook RENAME TO webhook_old;")
await conn.execute("ALTER TABLE client RENAME TO client_old;")
await conn.execute("ALTER TABLE matrix_message RENAME TO matrix_message_old;")

await conn.execute(
f"""CREATE TABLE client (
"""CREATE TABLE client (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
PRIMARY KEY (user_id)
Expand Down Expand Up @@ -60,13 +62,21 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None:
PRIMARY KEY (url)
)"""
)

if needs_migration:
await migrate_legacy_to_v1(conn)


async def migrate_legacy_to_v1(conn: Connection) -> None:
await conn.execute("INSERT INTO client (user_id, token) SELECT user_id, token FROM client_old")
await conn.execute(
"INSERT INTO matrix_message (message_id, room_id, event_id) SELECT message_id, room_id, event_id FROM matrix_message_old"
"INSERT INTO matrix_message (message_id, room_id, event_id) "
"SELECT message_id, room_id, event_id FROM matrix_message_old"
)
await conn.execute("CREATE TABLE needs_post_migration(noop INTEGER PRIMARY KEY)")


@upgrade_table.register(description="Add etag and fetched_at to avatar table", upgrades_to=2)
async def upgrade_v2(conn: Connection) -> None:
await conn.execute("ALTER TABLE avatar ADD COLUMN etag TEXT")
await conn.execute("ALTER TABLE avatar ADD COLUMN fetched_at INTEGER")