From 579ccc1fba97b9140c99312171bd3e01f11fb3ca Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Tue, 27 Jan 2026 16:37:32 -0300 Subject: [PATCH 1/9] speed up team validation --- metamon/backend/team_prediction/validate.py | 155 +++++++++++++++++++- tools/persistent_showdown_validator.js | 110 ++++++++++++++ 2 files changed, 258 insertions(+), 7 deletions(-) create mode 100644 tools/persistent_showdown_validator.js diff --git a/metamon/backend/team_prediction/validate.py b/metamon/backend/team_prediction/validate.py index 3475d51d8..bc3315baf 100644 --- a/metamon/backend/team_prediction/validate.py +++ b/metamon/backend/team_prediction/validate.py @@ -1,9 +1,12 @@ -import subprocess -import random -from typing import List, Tuple -import gc import argparse +import atexit +import json import os +from pathlib import Path +import random +import shutil +import subprocess +from typing import List, Optional import tqdm from poke_env.teambuilder import ConstantTeambuilder @@ -19,13 +22,151 @@ ) from metamon.tokenizer import get_tokenizer +_REPO_ROOT = Path(__file__).resolve().parents[3] +_PERSISTENT_VALIDATOR = None +_PERSISTENT_VALIDATOR_DISABLED = False + + +def _candidate_node_cwds(repo_root: Path) -> List[Path]: + return [repo_root, repo_root / "server" / "pokemon-showdown"] + + +def _resolve_node_cwd(repo_root: Path) -> Path: + for cwd in _candidate_node_cwds(repo_root): + if (cwd / "node_modules" / "pokemon-showdown").exists(): + return cwd + if (cwd / "node_modules" / ".bin" / "pokemon-showdown").exists(): + return cwd + return repo_root + + +def _find_showdown_bin(repo_root: Path) -> Optional[str]: + local_bins = [ + repo_root / "node_modules" / ".bin" / "pokemon-showdown", + repo_root + / "server" + / "pokemon-showdown" + / "node_modules" + / ".bin" + / "pokemon-showdown", + ] + for bin_path in local_bins: + if bin_path.exists(): + return str(bin_path) + return shutil.which("pokemon-showdown") + + +def _resolve_showdown_validate_cmd(format_id: str, cmd: Optional[List[str]]) -> List[str]: + if cmd is not None: + return cmd + [format_id] + showdown_bin = _find_showdown_bin(_REPO_ROOT) + if showdown_bin: + return [showdown_bin, "validate-team", format_id] + return ["npx", "pokemon-showdown", "validate-team", format_id] + + +class PersistentShowdownValidator: + def __init__(self, repo_root: Path): + self._script_path = repo_root / "tools" / "persistent_showdown_validator.js" + if not self._script_path.exists(): + raise FileNotFoundError(f"Missing validator script at {self._script_path}") + self._cwd = _resolve_node_cwd(repo_root) + self._proc = self._start_process() + if not self._ping(): + self.close() + raise RuntimeError("Persistent validator failed to start") + + def _start_process(self) -> subprocess.Popen: + return subprocess.Popen( + ["node", str(self._script_path)], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + text=True, + cwd=str(self._cwd), + bufsize=1, + ) + + def _send(self, payload: dict) -> Optional[dict]: + if self._proc.poll() is not None: + return None + if self._proc.stdin is None or self._proc.stdout is None: + return None + try: + self._proc.stdin.write(json.dumps(payload) + "\n") + self._proc.stdin.flush() + except BrokenPipeError: + return None + line = self._proc.stdout.readline() + if not line: + return None + try: + return json.loads(line) + except json.JSONDecodeError: + return None + + def _ping(self) -> bool: + response = self._send({"format": "gen1ou", "team": ""}) + return response is not None + + def validate(self, team_str: str, 
format_id: str) -> tuple[bool, List[str]]: + response = self._send({"format": format_id, "team": team_str}) + if response is None: + raise RuntimeError("Validator process is not responding") + ok = bool(response.get("ok")) + errors = response.get("errors") or [] + return ok, [str(err) for err in errors] + + def close(self) -> None: + if self._proc is None: + return + if self._proc.poll() is None: + try: + if self._proc.stdin is not None: + self._proc.stdin.close() + self._proc.terminate() + self._proc.wait(timeout=1) + except subprocess.TimeoutExpired: + self._proc.kill() + self._proc = None + + +def _get_persistent_validator() -> Optional[PersistentShowdownValidator]: + global _PERSISTENT_VALIDATOR, _PERSISTENT_VALIDATOR_DISABLED + if _PERSISTENT_VALIDATOR_DISABLED: + return None + if _PERSISTENT_VALIDATOR is None: + try: + _PERSISTENT_VALIDATOR = PersistentShowdownValidator(_REPO_ROOT) + atexit.register(_PERSISTENT_VALIDATOR.close) + except Exception as exc: # pragma: no cover - best-effort optimization + print(f"Persistent validator unavailable, falling back to CLI: {exc}") + _PERSISTENT_VALIDATOR_DISABLED = True + return None + return _PERSISTENT_VALIDATOR + def validate_showdown_team( team_str: str, format_id: str = "gen1ou", - cmd: List[str] = ["npx", "pokemon-showdown", "validate-team"], -) -> Tuple[bool, List[str]]: - full_cmd = cmd + [format_id] + cmd: Optional[List[str]] = None, +) -> bool: + validator = _get_persistent_validator() + if validator is not None: + global _PERSISTENT_VALIDATOR_DISABLED + try: + ok, errors = validator.validate(team_str, format_id) + except Exception as exc: # pragma: no cover - best-effort optimization + validator.close() + _PERSISTENT_VALIDATOR_DISABLED = True + print(f"Persistent validator failed, falling back to CLI: {exc}") + else: + if ok: + return True + print(errors) + return False + + full_cmd = _resolve_showdown_validate_cmd(format_id, cmd) proc = subprocess.run(full_cmd, input=team_str, text=True, capture_output=True) diff --git a/tools/persistent_showdown_validator.js b/tools/persistent_showdown_validator.js new file mode 100644 index 000000000..e83c73cb3 --- /dev/null +++ b/tools/persistent_showdown_validator.js @@ -0,0 +1,110 @@ +#!/usr/bin/env node + +/** + * Persistent Pokemon Showdown team validator. + * + * Protocol: + * - Read JSON lines on stdin: {"format": "gen1ou", "team": ""} + * - Write JSON lines on stdout: {"ok": true/false, "errors": ["..."]} + */ + +const readline = require("readline"); + +let TeamValidator; +let Teams; + +try { + ({ TeamValidator, Teams } = require("pokemon-showdown")); +} catch (errPrimary) { + try { + ({ TeamValidator } = require("pokemon-showdown/dist/sim/team-validator")); + } catch (errSecondary) { + const message = + "Unable to load pokemon-showdown. 
Install it locally with npm install."; + console.error(message); + console.error(String(errPrimary)); + console.error(String(errSecondary)); + process.exit(1); + } +} + +if (!Teams) { + try { + const teamsModule = require("pokemon-showdown/dist/sim/teams"); + Teams = teamsModule.Teams || teamsModule; + } catch (err) { + console.error("Unable to load pokemon-showdown Teams module."); + console.error(String(err)); + process.exit(1); + } +} + +const validatorsByFormat = new Map(); + +function getValidator(format) { + if (!validatorsByFormat.has(format)) { + validatorsByFormat.set(format, TeamValidator.get(format)); + } + return validatorsByFormat.get(format); +} + +function normalizeErrors(result) { + if (!result) return []; + if (Array.isArray(result)) return result.map((e) => String(e)); + return [String(result)]; +} + +function respond(payload) { + process.stdout.write(`${JSON.stringify(payload)}\n`); +} + +const rl = readline.createInterface({ + input: process.stdin, + crlfDelay: Infinity, +}); + +rl.on("line", (line) => { + const trimmed = line.trim(); + if (!trimmed) return; + + let req; + try { + req = JSON.parse(trimmed); + } catch (err) { + respond({ ok: false, errors: ["Invalid JSON input."] }); + return; + } + + const format = req.format; + const team = req.team; + + if (!format || typeof team !== "string") { + respond({ ok: false, errors: ["Input must include format and team string."] }); + return; + } + + let validator; + try { + validator = getValidator(format); + } catch (err) { + respond({ ok: false, errors: [String(err)] }); + return; + } + + try { + const parsedTeam = Teams.import(team); + if (!parsedTeam) { + respond({ ok: false, errors: ["Invalid team data"] }); + return; + } + const result = validator.validateTeam(parsedTeam); + const errors = normalizeErrors(result); + respond({ ok: errors.length === 0, errors }); + } catch (err) { + respond({ ok: false, errors: [String(err)] }); + } +}); + +rl.on("close", () => { + process.exit(0); +}); From ded502994e668633ca49a9d760c8df3e4e3b3fd9 Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 18:50:22 -0300 Subject: [PATCH 2/9] enable baseline filtering + chaos files --- .../usage_stats/stat_scraper.py | 148 +++++++++++++----- 1 file changed, 105 insertions(+), 43 deletions(-) diff --git a/metamon/backend/team_prediction/usage_stats/stat_scraper.py b/metamon/backend/team_prediction/usage_stats/stat_scraper.py index 64a5f89a2..ffe9f7115 100644 --- a/metamon/backend/team_prediction/usage_stats/stat_scraper.py +++ b/metamon/backend/team_prediction/usage_stats/stat_scraper.py @@ -1,10 +1,11 @@ -import os -import argparse -import asyncio -import aiohttp -import aiofiles -from bs4 import BeautifulSoup -from urllib.parse import urljoin +import os +import re +import argparse +import asyncio +import aiohttp +import aiofiles +from bs4 import BeautifulSoup +from urllib.parse import urljoin base_url = "https://www.smogon.com/stats/" @@ -23,18 +24,61 @@ default=2024, help="End year for scraping (YYYY) (exclusive)", ) -parser.add_argument( - "--save_dir", - type=str, - default="./stats", - help="Local directory to save the scraped files", -) -args = parser.parse_args() - - -def ensure_dir(file_path): - if not os.path.exists(file_path): - os.makedirs(file_path) +parser.add_argument( + "--save_dir", + type=str, + default="./stats", + help="Local directory to save the scraped files", +) +parser.add_argument( + "--baselines", + type=str, + default="", + help=( + "Comma-separated baselines to keep (e.g. '0,1500,1695,1825'). 
" + "Empty = keep all." + ), +) +parser.add_argument( + "--min_baseline", + type=float, + default=None, + help="If set, only download files with baseline >= this value.", +) +parser.add_argument( + "--include_chaos", + action="store_true", + help="Download chaos/ JSON files (includes info.cutoff metadata).", +) +args = parser.parse_args() + + +BASELINE_RE = re.compile(r"-(\d+(?:\.\d+)?)\.(txt|json)$") + +SKIP_DIRS = {"monotype", "metagame"} +if not args.include_chaos: + SKIP_DIRS.add("chaos") + +allowed_baselines = None +if args.baselines.strip(): + allowed_baselines = { + float(x.strip()) for x in args.baselines.split(",") if x.strip() + } + + +def extract_baseline(href: str): + m = BASELINE_RE.search(href) + if m: + return float(m.group(1)) + if href.endswith(".txt") or href.endswith(".json"): + # Smogon convention: no explicit baseline means 1500. + return 1500.0 + return None + + +def ensure_dir(file_path): + if not os.path.exists(file_path): + os.makedirs(file_path) async def save_text_file(session, url, local_path): @@ -81,30 +125,48 @@ async def scrape(session, url, local_dir): soup = BeautifulSoup(text, "html.parser") tasks = [] - for link in soup.find_all("a"): - href = link.get("href") - if "chaos" in href or "monotype" in href or "metagame" in href: - continue - if href and not href.startswith("?"): - href_full = urljoin(url, href) - local_path = os.path.join(local_dir, href) - - if href.endswith("/") and href != "../": # It's a directory - ensure_dir(local_path) - task = asyncio.create_task( - scrape(session, href_full, local_path) - ) - tasks.append(task) - elif href.endswith(".txt") or href.endswith( - ".json" - ): # It's a txt file - print(f"Downloading {href_full} to {local_path}") - task = asyncio.create_task( - save_text_file(session, href_full, local_path) - ) - tasks.append(task) - - await asyncio.gather(*tasks) + for link in soup.find_all("a"): + href = link.get("href") + if not href or href.startswith("?"): + continue + if href.endswith("/"): + if href == "../": + continue + dirname = href.rstrip("/") + if dirname in SKIP_DIRS: + continue + href_full = urljoin(url, href) + local_path = os.path.join(local_dir, href) + ensure_dir(local_path) + task = asyncio.create_task( + scrape(session, href_full, local_path) + ) + tasks.append(task) + continue + + if href.endswith(".txt") or href.endswith(".json"): + baseline = extract_baseline(href) + if ( + allowed_baselines is not None + and baseline is not None + and baseline not in allowed_baselines + ): + continue + if ( + args.min_baseline is not None + and baseline is not None + and baseline < args.min_baseline + ): + continue + href_full = urljoin(url, href) + local_path = os.path.join(local_dir, href) + print(f"Downloading {href_full} to {local_path}") + task = asyncio.create_task( + save_text_file(session, href_full, local_path) + ) + tasks.append(task) + + await asyncio.gather(*tasks) except Exception as e: print(f"Error on url {url}: {e}") From 59454a12e59dbe01e5a45e0683ef9da8a161b1ef Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 18:55:16 -0300 Subject: [PATCH 3/9] aggregate by rank baseline --- .../usage_stats/create_usage_jsons.py | 50 ++++++++++++------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/metamon/backend/team_prediction/usage_stats/create_usage_jsons.py b/metamon/backend/team_prediction/usage_stats/create_usage_jsons.py index de1bf1def..2d22c162b 100644 --- a/metamon/backend/team_prediction/usage_stats/create_usage_jsons.py +++ 
b/metamon/backend/team_prediction/usage_stats/create_usage_jsons.py @@ -1,6 +1,7 @@ import os import json import argparse +from collections import defaultdict from tqdm import tqdm from metamon.backend.team_prediction.usage_stats.format_rules import ( @@ -20,54 +21,69 @@ def main(args): for gen in range(1, 10): for year in range(2014, 2026): for month in range(1, 13): + date = f"{year}-{month:02d}" stat_dir = os.path.join(args.smogon_stat_dir) - valid_movesets = [] + valid_movesets_by_rank = defaultdict(list) for format in VALID_TIERS: format_name = f"gen{gen}{format.name.lower()}" - stat = SmogonStat( + + ranks = SmogonStat.available_ranks( format_name, raw_stats_dir=stat_dir, - date=f"{year}-{month:02d}", + date=date, ) - if stat.movesets: - # if we find data for this, save it + + for rank in ranks: + stat = SmogonStat( + format_name, + raw_stats_dir=stat_dir, + date=date, + rank=rank, + verbose=False, + ) + if not stat.movesets: + continue + path = os.path.join( args.save_dir, "movesets_data", f"gen{gen}", f"{format.name.lower()}", - f"{year}-{month:02d}.json", + rank, + f"{date}.json", ) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w") as f: json.dump(stat.movesets, f) - valid_movesets.append(stat.movesets) - check_cheatsheet = {} - for mon in stat.movesets.keys(): - checks = stat.movesets[mon]["checks"] - check_cheatsheet[mon] = checks + check_cheatsheet = { + mon: stat.movesets[mon]["checks"] + for mon in stat.movesets.keys() + } path = os.path.join( args.save_dir, "checks_data", f"gen{gen}", f"{format.name.lower()}", - f"{year}-{month:02d}.json", + rank, + f"{date}.json", ) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w") as f: json.dump(check_cheatsheet, f) + + valid_movesets_by_rank[rank].append(stat.movesets) pbar.update(1) - if valid_movesets: - # merge all the tiers. used to lookup rare Pokémon choices, i.e. fooling around - # with low-tier Pokémon in OverUsed - inclusive_movesets = merge_movesets(valid_movesets) + + for rank, tier_movesets in valid_movesets_by_rank.items(): + inclusive_movesets = merge_movesets(tier_movesets) path = os.path.join( args.save_dir, "movesets_data", f"gen{gen}", "all_tiers", - f"{year}-{month:02d}.json", + rank, + f"{date}.json", ) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w") as f: From d961fc070629514d9b7d8fe0163c0b61c012b484 Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 22:01:19 -0300 Subject: [PATCH 4/9] enable aggregation by Smogon rank baselines --- .../usage_stats/stat_reader.py | 391 ++++++++++++++++-- 1 file changed, 358 insertions(+), 33 deletions(-) diff --git a/metamon/backend/team_prediction/usage_stats/stat_reader.py b/metamon/backend/team_prediction/usage_stats/stat_reader.py index e6fc86a94..7e000dbb1 100644 --- a/metamon/backend/team_prediction/usage_stats/stat_reader.py +++ b/metamon/backend/team_prediction/usage_stats/stat_reader.py @@ -1,11 +1,10 @@ import os -import copy import re import json import datetime import functools import warnings -from typing import Optional +from typing import Optional, Union, List from termcolor import colored import metamon @@ -29,6 +28,79 @@ EARLIEST_USAGE_STATS_DATE = datetime.date(2014, 1, 1) LATEST_USAGE_STATS_DATE = datetime.date(2025, 12, 1) +DEFAULT_USAGE_RANK = "1500" + +RankLike = Optional[Union[str, int, float]] + + +def normalize_rank(rank: RankLike) -> Optional[str]: + """ + Normalize a rank/baseline representation to a stable string used in filenames/dirs. 
+ Examples: + 1630.0 -> "1630" + "1760" -> "1760" + None -> None + """ + if rank is None: + return None + s = str(rank).strip() + if re.fullmatch(r"\d+(\.0+)?", s): + s = s.split(".")[0] + return s + + +def rank_from_moveset_filename(fmt: str, filename: str) -> Optional[str]: + """ + Extract the baseline/rank from a Smogon moveset filename. + Examples: gen1ou-0.txt, gen1ou-1500.txt, gen1ou-1630.txt, gen1ou-1760.txt + Returns normalized rank string or None if not applicable. + """ + if not filename.startswith(fmt): + return None + if filename.endswith(".txt.gz") or filename.endswith(".gz"): + return None + if not filename.endswith(".txt"): + return None + + stem = filename[:-4] + if stem == fmt: + # Smogon convention: no explicit baseline means 1500. + return "1500" + + m = re.match(rf"^{re.escape(fmt)}-(\d+(?:\.\d+)?)$", stem) + if not m: + return None + return normalize_rank(m.group(1)) + + +def list_available_ranks_in_moveset_dir(moveset_dir: str, fmt: str) -> List[str]: + if not os.path.isdir(moveset_dir): + return [] + ranks = set() + for fn in os.listdir(moveset_dir): + r = rank_from_moveset_filename(fmt, fn) + if r is not None: + ranks.add(r) + return sorted(ranks, key=lambda x: float(x)) + + +def list_available_usage_ranks(format: str) -> List[str]: + """ + List available baseline/rank subdirectories in the processed usage-stats dataset + for a given format (e.g., gen4ou). + """ + gen, tier = int(format[3]), format[4:] + usage_stats_path = metamon.data.download.download_usage_stats(gen) + base = os.path.join(usage_stats_path, "movesets_data", f"gen{gen}", f"{tier}") + if not os.path.isdir(base): + return [] + ranks = [d for d in os.listdir(base) if os.path.isdir(os.path.join(base, d))] + return sorted( + ranks, + key=lambda x: float(x) if re.fullmatch(r"\d+(\.\d+)?", x) else x, + ) + + def parse_pokemon_moveset(file_path): @@ -267,17 +339,27 @@ def __init__( self.data_paths = [os.path.join(raw_stats_dir, date) for date in dates] self.format = format - self.rank = rank + self.rank = normalize_rank(rank) self.verbose = verbose self._movesets = {} self._inclusive = {} self._usage = None + self._available_ranks: List[str] = [] self._load() self._name_conversion = { pokemon_name(pokemon): pokemon for pokemon in self._movesets.keys() } + @staticmethod + def available_ranks(format: str, raw_stats_dir: str, date: str) -> List[str]: + moveset_dir = os.path.join(raw_stats_dir, date, "moveset") + return list_available_ranks_in_moveset_dir(moveset_dir, format) + + @property + def available_ranks_loaded(self) -> List[str]: + return list(self._available_ranks) + def _load(self): moveset_paths = [] for data_path in self.data_paths: @@ -286,21 +368,49 @@ def _load(self): moveset_paths.append(moveset_path) if len(moveset_paths) == 0: - print(f"No moveset data found for {self.format} in {self.data_paths}") + if self.verbose: + print(f"No moveset data found for {self.format} in {self.data_paths}") self._movesets = {} + self._available_ranks = [] return _movesets = [] + ranks_seen = set() for moveset_path in moveset_paths: - format_data = [ - x for x in os.listdir(moveset_path) if x.startswith(self.format + "-") - ] - if self.rank is not None: - format_data = [x for x in format_data if self.rank in x] - _movesets += [ - parse_pokemon_moveset(os.path.join(moveset_path, x)) - for x in format_data - ] + files_by_rank = {} + for fn in os.listdir(moveset_path): + r = rank_from_moveset_filename(self.format, fn) + if r is None: + continue + files_by_rank.setdefault(r, []).append(fn) + + 
ranks_seen.update(files_by_rank.keys()) + + if not files_by_rank: + continue + + if self.rank is None: + available = sorted(files_by_rank.keys(), key=lambda x: float(x)) + raise ValueError( + f"SmogonStat requires a baseline/rank for {self.format}. " + f"Available ranks in {moveset_path}: {available}" + ) + filenames = files_by_rank.get(self.rank, []) + + for fn in filenames: + fp = os.path.join(moveset_path, fn) + try: + _movesets.append(parse_pokemon_moveset(fp)) + except Exception as e: + if self.verbose: + warnings.warn(colored(f"Failed parsing {fp}: {e}", "red")) + + self._available_ranks = sorted(ranks_seen, key=lambda x: float(x)) + + if not _movesets: + self._movesets = {} + return + self._movesets = { pokemon_name(k): v for k, v in merge_movesets(_movesets).items() } @@ -360,7 +470,12 @@ def usage(self): def load_between_dates( - dir_path: str, start_year: int, start_month: int, end_year: int, end_month: int + dir_path: str, + start_year: int, + start_month: int, + end_year: int, + end_month: int, + warn_if_empty: bool = True, ) -> dict: start_date = datetime.date(start_year, start_month, 1) end_date = datetime.date(end_year, end_month, 1) @@ -377,6 +492,8 @@ def load_between_dates( selected_data = [] for json_file in os.listdir(dir_path): + if not json_file.endswith(".json"): + continue year, month = json_file.replace(".json", "").split("-") date = datetime.date(year=int(year), month=int(month), day=1) if not start_date <= date <= end_date: @@ -384,8 +501,7 @@ def load_between_dates( with open(os.path.join(dir_path, json_file), "r") as file: data = json.load(file) selected_data.append(data) - if not selected_data: - breakpoint() + if not selected_data and warn_if_empty: warnings.warn( colored( f"No Showdown usage stats found in {dir_path} between {start_date} and {end_date}", @@ -401,10 +517,19 @@ def __init__( format, start_date: datetime.date, end_date: datetime.date, + rank: RankLike = DEFAULT_USAGE_RANK, + rank_fallback: bool = True, + lower_rank_fallback: bool = True, + allow_legacy_layout: bool = True, verbose: bool = True, ): self.format = format.strip().lower() - self.rank = None + self.rank = normalize_rank(rank) + if self.rank is None: + raise ValueError( + f"PreloadedSmogonUsageStats requires rank=... 
for {self.format} " + "(ranked usage stats are stored by baseline)" + ) self.start_date = start_date self.end_date = end_date self.verbose = verbose @@ -412,12 +537,116 @@ def __init__( gen, tier = int(self.format[3]), self.format[4:] self.gen = gen usage_stats_path = metamon.data.download.download_usage_stats(gen) - movesets_path = os.path.join( + movesets_base = os.path.join( usage_stats_path, "movesets_data", f"gen{gen}", f"{tier}" ) - inclusive_path = os.path.join( + inclusive_base = os.path.join( usage_stats_path, "movesets_data", f"gen{gen}", "all_tiers" ) + movesets_path = os.path.join(movesets_base, self.rank) + inclusive_path = os.path.join(inclusive_base, self.rank) + + def _avail_ranks(base: str) -> list[str]: + if not os.path.isdir(base): + return [] + return sorted( + [d for d in os.listdir(base) if os.path.isdir(os.path.join(base, d))], + key=lambda x: float(x) if re.fullmatch(r"\d+(\.\d+)?", x) else x, + ) + + def _nearest_lower_rank(target: str, candidates: list[str]) -> Optional[str]: + try: + target_val = float(target) + except ValueError: + return None + lower = [] + for r in candidates: + try: + r_val = float(r) + except ValueError: + continue + if r_val < target_val: + lower.append((r_val, r)) + if not lower: + return None + return max(lower, key=lambda x: x[0])[1] + + if not os.path.isdir(movesets_path): + if allow_legacy_layout and os.path.isdir(movesets_base): + # Legacy layout: tier directory contains YYYY-MM.json directly. + if any(f.endswith(".json") for f in os.listdir(movesets_base)): + if self.verbose: + warnings.warn( + colored( + f"Legacy usage-stats layout detected for {self.format}. " + f"Ignoring rank={self.rank} and loading {movesets_base}.", + "yellow", + ) + ) + movesets_path = movesets_base + inclusive_path = inclusive_base + else: + avail = _avail_ranks(movesets_base) + fallback_rank = ( + _nearest_lower_rank(self.rank, avail) if rank_fallback else None + ) + if fallback_rank is not None: + if self.verbose: + warnings.warn( + colored( + f"Requested rank={self.rank} not found for {self.format}. " + f"Falling back to nearest rank={fallback_rank}.", + "yellow", + ) + ) + self.rank = fallback_rank + movesets_path = os.path.join(movesets_base, self.rank) + inclusive_path = os.path.join(inclusive_base, self.rank) + else: + raise FileNotFoundError( + f"Movesets data not found for {self.format} at rank={self.rank}. " + f"Available ranks: {avail}" + ) + else: + avail = _avail_ranks(movesets_base) + fallback_rank = ( + _nearest_lower_rank(self.rank, avail) if rank_fallback else None + ) + if fallback_rank is not None: + if self.verbose: + warnings.warn( + colored( + f"Requested rank={self.rank} not found for {self.format}. " + f"Falling back to nearest rank={fallback_rank}.", + "yellow", + ) + ) + self.rank = fallback_rank + movesets_path = os.path.join(movesets_base, self.rank) + inclusive_path = os.path.join(inclusive_base, self.rank) + else: + raise FileNotFoundError( + f"Movesets data not found for {self.format} at rank={self.rank}. " + f"Available ranks: {avail}" + ) + + if not os.path.isdir(inclusive_path): + if allow_legacy_layout and os.path.isdir(inclusive_base): + if any(f.endswith(".json") for f in os.listdir(inclusive_base)): + inclusive_path = inclusive_base + else: + avail = _avail_ranks(inclusive_base) + raise FileNotFoundError( + f"All-tiers movesets not found for gen{gen} at rank={self.rank}. 
" + f"Available ranks: {avail}" + ) + else: + avail = _avail_ranks(inclusive_base) + raise FileNotFoundError( + f"All-tiers movesets not found for gen{gen} at rank={self.rank}. " + f"Available ranks: {avail}" + ) + # data is split by year and month if not os.path.exists(movesets_path) or not os.path.exists(inclusive_path): raise FileNotFoundError( @@ -430,6 +659,11 @@ def __init__( end_year=end_date.year, end_month=end_date.month, ) + if not self._movesets: + raise FileNotFoundError( + f"No usage stats found for {self.format} at rank={self.rank} " + f"between {start_date} and {end_date} in {movesets_path}." + ) self._inclusive = load_between_dates( inclusive_path, start_year=EARLIEST_USAGE_STATS_DATE.year, @@ -437,6 +671,40 @@ def __init__( end_year=LATEST_USAGE_STATS_DATE.year, end_month=LATEST_USAGE_STATS_DATE.month, ) + self._lower_rank_fallbacks: list[tuple[str, dict, dict]] = [] + if lower_rank_fallback: + avail = _avail_ranks(movesets_base) + lower_ranks = [ + r for r in avail if float(r) < float(self.rank) + ] + lower_ranks.sort(key=lambda x: float(x), reverse=True) + for r in lower_ranks: + lower_movesets_path = os.path.join(movesets_base, r) + lower_inclusive_path = os.path.join(inclusive_base, r) + if not os.path.isdir(lower_movesets_path) or not os.path.isdir( + lower_inclusive_path + ): + continue + lower_movesets = load_between_dates( + lower_movesets_path, + start_year=start_date.year, + start_month=start_date.month, + end_year=end_date.year, + end_month=end_date.month, + warn_if_empty=False, + ) + lower_inclusive = load_between_dates( + lower_inclusive_path, + start_year=EARLIEST_USAGE_STATS_DATE.year, + start_month=EARLIEST_USAGE_STATS_DATE.month, + end_year=LATEST_USAGE_STATS_DATE.year, + end_month=LATEST_USAGE_STATS_DATE.month, + warn_if_empty=False, + ) + if lower_movesets or lower_inclusive: + self._lower_rank_fallbacks.append( + (r, lower_movesets, lower_inclusive) + ) def _load(self): pass @@ -446,18 +714,38 @@ def _inclusive_search(self, key): key_id = pokemon_name(key) recent = self._movesets.get(key_id, {}) alltime = self._inclusive.get(key_id, {}) - if not (recent or alltime): + if not (recent or alltime or self._lower_rank_fallbacks): return None - if recent and alltime: - # use the alltime stats to selectively get keys that exist - # in recent but are unhelpful for team prediction. - no_info = {"Nothing": 1.0} - for key, value in recent.items(): + no_info = {"Nothing": 1.0} + + def _apply_field_fallback(primary: dict, fallback: dict) -> dict: + if not fallback: + return primary + if not primary: + return fallback + for field, value in fallback.items(): if value == no_info: - if alltime.get(key, {}) != no_info: - recent[key] = alltime[key] - return recent if recent else alltime + continue + if field not in primary or primary.get(field) == no_info: + primary[field] = value + return primary + + # Start with tier stats for the requested rank; do not use all_tiers yet. + primary = recent if recent else {} + + # First, walk downward through lower-rank tier stats. + for _, lower_recent, _ in self._lower_rank_fallbacks: + primary = _apply_field_fallback(primary, lower_recent.get(key_id, {})) + + # If still missing, fall back to all_tiers for the requested rank. + primary = _apply_field_fallback(primary, alltime) + + # Finally, use lower-rank all_tiers as a last resort. 
+ for _, _, lower_alltime in self._lower_rank_fallbacks: + primary = _apply_field_fallback(primary, lower_alltime.get(key_id, {})) + + return primary if primary else None def __getitem__(self, key): entry = Dex.from_gen(self.gen).get_pokedex_entry(key) @@ -475,6 +763,10 @@ def get_usage_stats( format, start_date: Optional[datetime.date] = None, end_date: Optional[datetime.date] = None, + rank: RankLike = DEFAULT_USAGE_RANK, + rank_fallback: bool = True, + lower_rank_fallback: bool = True, + allow_legacy_layout: bool = True, ) -> PreloadedSmogonUsageStats: if start_date is None or start_date < EARLIEST_USAGE_STATS_DATE: start_date = EARLIEST_USAGE_STATS_DATE @@ -486,20 +778,53 @@ def get_usage_stats( else: # force to start of months to prevent cache miss (we only have monthly stats anyway) end_date = datetime.date(end_date.year, end_date.month, 1) - return _cached_smogon_stats(format, start_date, end_date) + rank_norm = normalize_rank(rank) + if rank_norm is None: + rank_norm = DEFAULT_USAGE_RANK + return _cached_smogon_stats( + format, + start_date, + end_date, + rank_norm, + rank_fallback, + lower_rank_fallback, + allow_legacy_layout, + ) @functools.lru_cache(maxsize=64) -def _cached_smogon_stats(format, start_date: datetime.date, end_date: datetime.date): - print(f"Loading usage stats for {format} between {start_date} and {end_date}") +def _cached_smogon_stats( + format, + start_date: datetime.date, + end_date: datetime.date, + rank: Optional[str], + rank_fallback: bool, + lower_rank_fallback: bool, + allow_legacy_layout: bool, +): + if rank is None: + raise ValueError("rank is required for cached usage stats.") + print( + f"Loading usage stats for {format} between {start_date} and {end_date} (rank={rank})" + ) return PreloadedSmogonUsageStats( - format=format, start_date=start_date, end_date=end_date, verbose=False + format=format, + start_date=start_date, + end_date=end_date, + rank=rank, + rank_fallback=rank_fallback, + lower_rank_fallback=lower_rank_fallback, + allow_legacy_layout=allow_legacy_layout, + verbose=False, ) if __name__ == "__main__": stats = get_usage_stats( - "gen9ou", datetime.date(2023, 1, 1), datetime.date(2025, 6, 1) + "gen9ou", + datetime.date(2023, 1, 1), + datetime.date(2025, 6, 1), + rank=DEFAULT_USAGE_RANK, ) print(len(stats.usage)) for mon in sorted( From 6c4158c2adaab8ddf26465228a3f58860fed6a01 Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 22:02:46 -0300 Subject: [PATCH 5/9] propagate rank argument for usage_stats calls --- metamon/backend/team_prediction/predictor.py | 18 ++++++++++++++++-- metamon/backend/team_prediction/team.py | 4 ++-- .../usage_stats/legacy_team_builder.py | 10 ++++++++-- metamon/backend/team_prediction/vocabulary.py | 10 ++++++++-- metamon/baselines/base.py | 10 ++++++++-- metamon/tokenizer/tokenizer.py | 7 +++++-- tools/patch_pokeagent_gen9ou_trajs.py | 10 ++++++++-- 7 files changed, 55 insertions(+), 14 deletions(-) diff --git a/metamon/backend/team_prediction/predictor.py b/metamon/backend/team_prediction/predictor.py index fa0fc61b2..10758cc5c 100644 --- a/metamon/backend/team_prediction/predictor.py +++ b/metamon/backend/team_prediction/predictor.py @@ -17,14 +17,23 @@ ) from metamon.backend.team_prediction.usage_stats import ( PreloadedSmogonUsageStats, + DEFAULT_USAGE_RANK, + RankLike, ) from metamon.backend.replay_parser.str_parsing import pokemon_name from metamon.backend.team_prediction.team import TeamSet, PokemonSet, Roster class TeamPredictor(ABC): - def __init__(self, replay_stats_dir: Optional[str] = 
None): + def __init__( + self, + replay_stats_dir: Optional[str] = None, + usage_stats_rank: RankLike = DEFAULT_USAGE_RANK, + ): + if usage_stats_rank is None: + usage_stats_rank = DEFAULT_USAGE_RANK self.replay_stats_dir = replay_stats_dir + self.usage_stats_rank = usage_stats_rank def bin_usage_stats_dates( self, date: datetime.date @@ -50,6 +59,7 @@ def get_legacy_team_builder(self, format: str, date: datetime.date) -> TeamBuild format=format, start_date=start_date, end_date=end_date, + rank=self.usage_stats_rank, ) def get_usage_stats( @@ -209,9 +219,13 @@ def __init__( top_k_scored_teams: int = 10, top_k_scored_movesets: int = 3, replay_stats_dir: Optional[str] = None, + usage_stats_rank: RankLike = DEFAULT_USAGE_RANK, ): assert not isinstance(top_k_consistent_teams, str) - super().__init__(replay_stats_dir) + super().__init__( + replay_stats_dir, + usage_stats_rank=usage_stats_rank, + ) self.stat_format = None self.top_k_consistent_teams = top_k_consistent_teams self.top_k_consistent_movesets = top_k_consistent_movesets diff --git a/metamon/backend/team_prediction/team.py b/metamon/backend/team_prediction/team.py index 513d1baa3..800216737 100644 --- a/metamon/backend/team_prediction/team.py +++ b/metamon/backend/team_prediction/team.py @@ -15,13 +15,13 @@ unknown, ) from metamon.backend.replay_parser.str_parsing import pokemon_name -from metamon.backend.team_prediction.usage_stats import get_usage_stats +from metamon.backend.team_prediction.usage_stats import get_usage_stats, DEFAULT_USAGE_RANK from metamon.backend.showdown_dex import Dex def moveset_size(pokemon_name: str, gen: int) -> int: # attempts to handle cases where we would expect a Pokemon to have less than 4 moves - stat = get_usage_stats(f"gen{gen}ubers") + stat = get_usage_stats(f"gen{gen}ubers", rank=DEFAULT_USAGE_RANK) try: moves = len(set(stat[pokemon_name]["moves"].keys()) - {"Nothing"}) except KeyError: diff --git a/metamon/backend/team_prediction/usage_stats/legacy_team_builder.py b/metamon/backend/team_prediction/usage_stats/legacy_team_builder.py index 024536b8f..0f9b55edf 100644 --- a/metamon/backend/team_prediction/usage_stats/legacy_team_builder.py +++ b/metamon/backend/team_prediction/usage_stats/legacy_team_builder.py @@ -7,7 +7,7 @@ import numpy as np import metamon -from metamon.backend.team_prediction.usage_stats import get_usage_stats +from metamon.backend.team_prediction.usage_stats import get_usage_stats, RankLike from metamon.backend.team_prediction.usage_stats.constants import ( HIDDEN_POWER_IVS, HIDDEN_POWER_DVS, @@ -43,12 +43,18 @@ def __init__( format: str, start_date: datetime.date, end_date: datetime.date, + rank: RankLike = None, verbose: bool = False, remove_banned: bool = False, ): self.format = format self.gen = metamon.backend.format_to_gen(format) - self.stat = get_usage_stats(format, start_date, end_date) + self.stat = get_usage_stats( + format, + start_date, + end_date, + rank=rank, + ) if remove_banned: self.stat.remove_banned_pm() self.verbose = verbose diff --git a/metamon/backend/team_prediction/vocabulary.py b/metamon/backend/team_prediction/vocabulary.py index c101aa5ee..ce305cc00 100644 --- a/metamon/backend/team_prediction/vocabulary.py +++ b/metamon/backend/team_prediction/vocabulary.py @@ -11,7 +11,10 @@ import metamon from metamon.tokenizer import PokemonTokenizer, UNKNOWN_TOKEN from metamon.backend.team_prediction.team import PokemonSet, TeamSet -from metamon.backend.team_prediction.usage_stats import get_usage_stats +from metamon.backend.team_prediction.usage_stats import ( 
+ get_usage_stats, + DEFAULT_USAGE_RANK, +) def create_vocabularies(scan_dataset: bool = False): @@ -42,7 +45,10 @@ def create_vocabularies(scan_dataset: bool = False): for tier in ["ou", "uu", "ubers", "nu"]: format = f"gen{gen}{tier}" stat = get_usage_stats( - format, start_date=date(2015, 1, 1), end_date=date(2025, 1, 1) + format, + start_date=date(2015, 1, 1), + end_date=date(2025, 1, 1), + rank=DEFAULT_USAGE_RANK, ) for pokemon_name, data in stat._inclusive.items(): diff --git a/metamon/baselines/base.py b/metamon/baselines/base.py index 2e20cab1b..fc5580306 100644 --- a/metamon/baselines/base.py +++ b/metamon/baselines/base.py @@ -18,7 +18,10 @@ ) from metamon.baselines import GEN_DATA -from metamon.backend.team_prediction.usage_stats import get_usage_stats +from metamon.backend.team_prediction.usage_stats import ( + get_usage_stats, + DEFAULT_USAGE_RANK, +) class Baseline(Player, ABC): @@ -543,7 +546,10 @@ def switch_scores( gen, format = self.get_gen_format(battle) if check_w > 0: - smogon_stats = get_usage_stats(f"gen{gen}{format.lower()}") + smogon_stats = get_usage_stats( + f"gen{gen}{format.lower()}", + rank=DEFAULT_USAGE_RANK, + ) switch_scores = {} for switch in switches: diff --git a/metamon/tokenizer/tokenizer.py b/metamon/tokenizer/tokenizer.py index 118d92e2f..fe94d758f 100644 --- a/metamon/tokenizer/tokenizer.py +++ b/metamon/tokenizer/tokenizer.py @@ -117,7 +117,10 @@ def get_tokenizer(choice: str) -> PokemonTokenizer: DefaultActionSpace, ) from metamon.data import ParsedReplayDataset - from metamon.backend.team_prediction.usage_stats import get_usage_stats + from metamon.backend.team_prediction.usage_stats import ( + get_usage_stats, + DEFAULT_USAGE_RANK, + ) parser = ArgumentParser() parser.add_argument("--parsed_replay_root", required=True) @@ -133,7 +136,7 @@ def get_tokenizer(choice: str) -> PokemonTokenizer: # catch stray names from Smogon stats for format in SUPPORTED_BATTLE_FORMATS: - stat = get_usage_stats(format) + stat = get_usage_stats(format, rank=DEFAULT_USAGE_RANK) for pokemon_name_str, data in tqdm.tqdm(stat._inclusive.items()): tokenizer.add_token_for(pokemon_name(pokemon_name_str)) diff --git a/tools/patch_pokeagent_gen9ou_trajs.py b/tools/patch_pokeagent_gen9ou_trajs.py index 7f005e485..8d960be2f 100644 --- a/tools/patch_pokeagent_gen9ou_trajs.py +++ b/tools/patch_pokeagent_gen9ou_trajs.py @@ -18,12 +18,18 @@ from multiprocessing import Pool, cpu_count from metamon.interface import UniversalState -from metamon.backend.team_prediction.usage_stats import get_usage_stats +from metamon.backend.team_prediction.usage_stats import ( + get_usage_stats, + DEFAULT_USAGE_RANK, +) from metamon.backend.replay_parser.str_parsing import clean_name USAGE_STATS = get_usage_stats( - "gen9ou", start_date=date(2022, 1, 1), end_date=date(2025, 5, 31) + "gen9ou", + start_date=date(2022, 1, 1), + end_date=date(2025, 5, 31), + rank=DEFAULT_USAGE_RANK, ) From 7fe9c218c1941a3ce5b2924376b6b7670de17ed8 Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 22:03:13 -0300 Subject: [PATCH 6/9] propagate rank argument --- metamon/backend/team_prediction/usage_stats/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/metamon/backend/team_prediction/usage_stats/__init__.py b/metamon/backend/team_prediction/usage_stats/__init__.py index 3eeaf626d..c19f6b3f7 100644 --- a/metamon/backend/team_prediction/usage_stats/__init__.py +++ b/metamon/backend/team_prediction/usage_stats/__init__.py @@ -1,2 +1,8 @@ -from .stat_reader import get_usage_stats, 
PreloadedSmogonUsageStats +from .stat_reader import ( + get_usage_stats, + PreloadedSmogonUsageStats, + DEFAULT_USAGE_RANK, + RankLike, + list_available_usage_ranks, +) from .legacy_team_builder import TeamBuilder, PokemonStatsLookupError From d00e219aae40895c0cbd785db9dae3fc38681d00 Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 22:04:03 -0300 Subject: [PATCH 7/9] propagate rank argument for usage_stats calls --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1eabd4a3e..1e39ca495 100644 --- a/README.md +++ b/README.md @@ -1130,7 +1130,8 @@ from metamon.backend.team_prediction.usage_stats import get_usage_stats from datetime import date usage_stats = get_usage_stats("gen1ou", start_date=date(2017, 12, 1), - end_date=date(2018, 3, 30) + end_date=date(2018, 3, 30), + rank="1500", ) alakazam_info: dict = usage_stats["Alakazam"] # non alphanum chars and case are flexible ``` From 67ca0b8ae109ceba1106c4c15af5c4056e0d1595 Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 22:04:28 -0300 Subject: [PATCH 8/9] add sanity checks for stats_usage --- .../usage_stats/test_usage_stats.py | 29 ++++++++++++ .../test_usage_stats_rank_fallback.py | 44 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 metamon/backend/team_prediction/usage_stats/test_usage_stats.py create mode 100644 metamon/backend/team_prediction/usage_stats/test_usage_stats_rank_fallback.py diff --git a/metamon/backend/team_prediction/usage_stats/test_usage_stats.py b/metamon/backend/team_prediction/usage_stats/test_usage_stats.py new file mode 100644 index 000000000..aeda26e46 --- /dev/null +++ b/metamon/backend/team_prediction/usage_stats/test_usage_stats.py @@ -0,0 +1,29 @@ +import datetime + +from metamon.backend.team_prediction.usage_stats import get_usage_stats + +FORMAT = "gen9ou" +START_DATE = datetime.date(2025, 1, 1) +END_DATE = datetime.date(2025, 6, 1) +RANK = 1500 + +stats = get_usage_stats(FORMAT, START_DATE, END_DATE, rank=RANK) + +print("format", FORMAT) +print("date_window", START_DATE, END_DATE) +print("rank_requested", RANK) +print("rank_used", stats.rank) +print("pokemon_count", len(stats.usage)) + +# Show a small sample so we can sanity-check distributions. 
+top_n = 5 +for mon in stats.usage[:top_n]: + data = stats[mon] + top_moves = sorted(data.get("moves", {}).items(), key=lambda x: x[1], reverse=True)[:3] + top_items = sorted(data.get("items", {}).items(), key=lambda x: x[1], reverse=True)[:3] + top_abilities = sorted(data.get("abilities", {}).items(), key=lambda x: x[1], reverse=True)[:3] + print("\n", mon) + print(" count", data.get("count")) + print(" top_moves", top_moves) + print(" top_items", top_items) + print(" top_abilities", top_abilities) diff --git a/metamon/backend/team_prediction/usage_stats/test_usage_stats_rank_fallback.py b/metamon/backend/team_prediction/usage_stats/test_usage_stats_rank_fallback.py new file mode 100644 index 000000000..13ba18995 --- /dev/null +++ b/metamon/backend/team_prediction/usage_stats/test_usage_stats_rank_fallback.py @@ -0,0 +1,44 @@ +import os +import json +import datetime +from pathlib import Path + +if "METAMON_CACHE_DIR" not in os.environ: + os.environ["METAMON_CACHE_DIR"] = "/tmp/usage_stats_test" + +from metamon.backend.team_prediction.usage_stats import get_usage_stats + +base = Path('/tmp/usage_stats_test/usage-stats') +movesets_base = base / 'movesets_data' / 'gen9' +checks_base = base / 'checks_data' / 'gen9' + +for p in [ + movesets_base / 'ou' / '1630', + movesets_base / 'ou' / '1500', + movesets_base / 'all_tiers' / '1630', + movesets_base / 'all_tiers' / '1500', + checks_base +]: + p.mkdir(parents=True, exist_ok=True) + +month = '2025-01' +json.dump( + {'pikachu': {'count': 1, 'moves': {'Nothing': 1.0}}}, + open(movesets_base / 'ou' / '1630' / f'{month}.json', 'w'), +) +json.dump( + {'pikachu': {'count': 1, 'moves': {'Thunderbolt': 1.0}}}, + open(movesets_base / 'ou' / '1500' / f'{month}.json', 'w'), +) +json.dump( + {'pikachu': {'count': 1, 'moves': {'Iron Tail': 1.0}}}, + open(movesets_base / 'all_tiers' / '1630' / f'{month}.json', 'w'), +) +json.dump( + {'pikachu': {'count': 1, 'moves': {'Surf': 1.0}}}, + open(movesets_base / 'all_tiers' / '1500' / f'{month}.json', 'w'), +) + +stats = get_usage_stats('gen9ou', datetime.date(2025,1,1), datetime.date(2025,1,1), rank=1649) +print("rank_used", stats.rank) # expect 1630 (nearest lower) +print("pikachu_moves", stats['pikachu']['moves']) # expect Thunderbolt (lower-rank tier, before all_tiers) From b8ca468343a59c3867a4e19808fc330e5b48aee3 Mon Sep 17 00:00:00 2001 From: nivaldoh Date: Thu, 29 Jan 2026 22:51:43 -0300 Subject: [PATCH 9/9] add exponential backoff, limit concurrency --- .../usage_stats/stat_scraper.py | 223 ++++++++++-------- 1 file changed, 129 insertions(+), 94 deletions(-) diff --git a/metamon/backend/team_prediction/usage_stats/stat_scraper.py b/metamon/backend/team_prediction/usage_stats/stat_scraper.py index ffe9f7115..c240f6f41 100644 --- a/metamon/backend/team_prediction/usage_stats/stat_scraper.py +++ b/metamon/backend/team_prediction/usage_stats/stat_scraper.py @@ -50,12 +50,30 @@ action="store_true", help="Download chaos/ JSON files (includes info.cutoff metadata).", ) +parser.add_argument( + "--max_concurrency", + type=int, + default=8, + help="Maximum number of concurrent HTTP requests.", +) +parser.add_argument( + "--max_retries", + type=int, + default=5, + help="Maximum number of retries for a failed request.", +) +parser.add_argument( + "--backoff_base", + type=float, + default=0.5, + help="Base seconds for exponential backoff (retry delay = base * 2^attempt).", +) args = parser.parse_args() - - + + BASELINE_RE = re.compile(r"-(\d+(?:\.\d+)?)\.(txt|json)$") -SKIP_DIRS = {"monotype", "metagame"} +SKIP_DIRS = 
{"monotype", "metagame", "leads"} if not args.include_chaos: SKIP_DIRS.add("chaos") @@ -81,108 +99,125 @@ def ensure_dir(file_path): os.makedirs(file_path) -async def save_text_file(session, url, local_path): +async def save_text_file(session, url, local_path): # Check if the file already exists if os.path.isfile(local_path): print(f"File already exists: {local_path}") return - async with session.get(url) as response: - if response.status == 200: - text = await response.text() - async with aiofiles.open(local_path, "w", encoding="utf-8") as file: - await file.write(text) - - -async def scrape_base(session, url, local_dir, start_date, end_date): - async with session.get(url) as response: - text = await response.text() - soup = BeautifulSoup(text, "html.parser") - - tasks = [] - for link in soup.find_all("a"): - href = link.get("href") - if href and not href.startswith("?") and href != "../": - href_date = int(href[:4]) - href_full = urljoin(url, href) - local_path = os.path.join(local_dir, href) - - if ( - href.endswith("/") - and href_date >= start_date - and href_date < end_date - ): # It's a directory - ensure_dir(local_path) - task = asyncio.create_task(scrape(session, href_full, local_path)) - tasks.append(task) - - await asyncio.gather(*tasks) - + async with session.get(url) as response: + if response.status == 200: + text = await response.text() + async with aiofiles.open(local_path, "w", encoding="utf-8") as file: + await file.write(text) + + +async def _fetch_with_retries(session, url): + for attempt in range(args.max_retries + 1): + try: + async with session.get(url) as response: + if response.status != 200: + raise RuntimeError(f"HTTP {response.status}") + return await response.text() + except Exception: + if attempt >= args.max_retries: + raise + await asyncio.sleep(args.backoff_base * (2**attempt)) + + +async def scrape_base(session, url, local_dir, start_date, end_date, sem): + async with sem: + text = await _fetch_with_retries(session, url) + soup = BeautifulSoup(text, "html.parser") + tasks = [] + for link in soup.find_all("a"): + href = link.get("href") + if href and not href.startswith("?") and href != "../": + href_date = int(href[:4]) + href_full = urljoin(url, href) + local_path = os.path.join(local_dir, href) + + if ( + href.endswith("/") + and href_date >= start_date + and href_date < end_date + ): # It's a directory + ensure_dir(local_path) + task = asyncio.create_task( + scrape(session, href_full, local_path, sem) + ) + tasks.append(task) + + await asyncio.gather(*tasks) -async def scrape(session, url, local_dir): - try: - async with session.get(url) as response: - text = await response.text() - soup = BeautifulSoup(text, "html.parser") - tasks = [] - for link in soup.find_all("a"): - href = link.get("href") - if not href or href.startswith("?"): +async def scrape(session, url, local_dir, sem): + try: + async with sem: + text = await _fetch_with_retries(session, url) + soup = BeautifulSoup(text, "html.parser") + + tasks = [] + for link in soup.find_all("a"): + href = link.get("href") + if not href or href.startswith("?"): + continue + if href.endswith("/"): + if href == "../": continue - if href.endswith("/"): - if href == "../": - continue - dirname = href.rstrip("/") - if dirname in SKIP_DIRS: - continue - href_full = urljoin(url, href) - local_path = os.path.join(local_dir, href) - ensure_dir(local_path) - task = asyncio.create_task( - scrape(session, href_full, local_path) - ) - tasks.append(task) + dirname = href.rstrip("/") + if dirname in SKIP_DIRS: continue 
+ href_full = urljoin(url, href) + local_path = os.path.join(local_dir, href) + ensure_dir(local_path) + task = asyncio.create_task( + scrape(session, href_full, local_path, sem) + ) + tasks.append(task) + continue - if href.endswith(".txt") or href.endswith(".json"): - baseline = extract_baseline(href) - if ( - allowed_baselines is not None - and baseline is not None - and baseline not in allowed_baselines - ): - continue - if ( - args.min_baseline is not None - and baseline is not None - and baseline < args.min_baseline - ): - continue - href_full = urljoin(url, href) - local_path = os.path.join(local_dir, href) - print(f"Downloading {href_full} to {local_path}") - task = asyncio.create_task( - save_text_file(session, href_full, local_path) - ) - tasks.append(task) - - await asyncio.gather(*tasks) - except Exception as e: - print(f"Error on url {url}: {e}") + if href.endswith(".txt") or href.endswith(".json"): + baseline = extract_baseline(href) + if ( + allowed_baselines is not None + and baseline is not None + and baseline not in allowed_baselines + ): + continue + if ( + args.min_baseline is not None + and baseline is not None + and baseline < args.min_baseline + ): + continue + href_full = urljoin(url, href) + local_path = os.path.join(local_dir, href) + print(f"Downloading {href_full} to {local_path}") + task = asyncio.create_task( + save_text_file(session, href_full, local_path) + ) + tasks.append(task) + + await asyncio.gather(*tasks) + except Exception as e: + print(f"Error on url {url}: {e}") ensure_dir(args.save_dir) -async def main(): - async with aiohttp.ClientSession() as session: - await scrape_base( - session, - base_url, - args.save_dir, - start_date=args.start_date, - end_date=args.end_date, - ) - - -asyncio.run(main()) +async def main(): + connector = aiohttp.TCPConnector(limit=args.max_concurrency) + sem = asyncio.Semaphore(args.max_concurrency) + async with aiohttp.ClientSession(connector=connector) as session: + await scrape_base( + session, + base_url, + args.save_dir, + start_date=args.start_date, + end_date=args.end_date, + sem=sem, + ) + + +asyncio.run(main())
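
A minimal usage sketch for the faster validation path introduced in PATCH 1/9. It assumes Node.js and the pokemon-showdown npm package are installed at the repo root or under server/pokemon-showdown; if the persistent Node process cannot start, validate_showdown_team quietly falls back to the one-shot CLI, so the call below works either way. The team paste is illustrative only.

    from metamon.backend.team_prediction.validate import validate_showdown_team

    # Showdown-export style paste, built line by line to keep the string unindented.
    team = "\n".join([
        "Tauros",
        "- Body Slam",
        "- Hyper Beam",
        "- Blizzard",
        "- Earthquake",
    ])

    # The first call spawns tools/persistent_showdown_validator.js and reuses the same
    # process for later calls. When the persistent validator is used, an invalid team
    # prints the validator's error messages and the call returns False.
    ok = validate_showdown_team(team, format_id="gen1ou")
    print("valid:", ok)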