From 239250a7fb270f3c0899dfd519021f764386a493 Mon Sep 17 00:00:00 2001 From: alinakbase Date: Wed, 3 Sep 2025 10:58:47 -0700 Subject: [PATCH 1/2] Add files via upload --- src/parsers/refseq_api.py | 899 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 899 insertions(+) create mode 100644 src/parsers/refseq_api.py diff --git a/src/parsers/refseq_api.py b/src/parsers/refseq_api.py new file mode 100644 index 0000000..66b690a --- /dev/null +++ b/src/parsers/refseq_api.py @@ -0,0 +1,899 @@ +import os +import uuid +from datetime import datetime, date +import re +import click +import pandas as pd +import requests +from typing import Optional +from typing import Any, Literal +from pyspark.sql.types import StructType, StructField, StringType + +from pyspark.sql import SparkSession +from delta import configure_spark_with_delta_pip + + +""" + +python refseq_api.py \ + --taxid "224325, 2741724, 193567" \ + --database refseq_api \ + --mode overwrite \ + --debug \ + --unique-per-taxon + +""" + + +# ---------------- Spark + Delta ---------------- + +def build_spark(database: str) -> SparkSession: + """ + Initialize a Spark session with Delta Lake support and create the specified database if it doesn't exist. + """ + builder = ( + SparkSession.builder.appName("NCBI Datasets -> CDM") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") + ) + spark = configure_spark_with_delta_pip(builder).getOrCreate() + + # Create the database namespace if it doesn't already exist + spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}") + return spark + + +def write_delta( + spark: SparkSession, + pandas_df: pd.DataFrame, + database: str, + table: str, + mode: str = "append") -> None: + """ + Write Pandas DataFrame to a Delta Lake table using Spark. + Supports append or overwrite. + Special handling for 'contig_collection' schema to avoid inference error. + """ + + if pandas_df is None or pandas_df.empty: + print(f"No data to write to {database}.{table}") + return + + print(f"Writing {table} with {len(pandas_df)} rows") + print(pandas_df.dtypes) + print(pandas_df.head(10)) + + # -------- Special schema for problematic table -------- + if table == "contig_collection": + schema = StructType([ + StructField("collection_id", StringType(), True), + StructField("contig_collection_type", StringType(), True), + StructField("ncbi_taxon_id", StringType(), True), + StructField("gtdb_taxon_id", StringType(), True), + ]) + pandas_df = pandas_df.astype(str).where(pandas_df.notnull(), None) + spark_df = spark.createDataFrame(pandas_df, schema=schema) + else: + spark_df = spark.createDataFrame(pandas_df) + + # -------- Writer with proper options -------- + writer = spark_df.write.format("delta").mode(mode) + if mode == "append": + writer = writer.option("mergeSchema", "true") + elif mode == "overwrite": + writer = writer.option("overwriteSchema", "true") + + writer.saveAsTable(f"{database}.{table}") + print(f"Saved {len(pandas_df)} rows to {database}.{table} (mode={mode})") + + +def preview_or_skip( + spark: SparkSession, + database: str, + table: str, + limit: int = 20) -> None: + + """ + Preview the first N rows of a Delta table if it exists. 
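+
+    A minimal usage sketch (illustrative; assumes the Spark session from build_spark and a
+    previously written table):
+        preview_or_skip(spark, "refseq_api", "entity", limit=5)
+    prints the first 5 rows of refseq_api.entity, or a skip message if the table is absent.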
+ """ + + full_table = f"{database}.{table}" + if spark.catalog.tableExists(full_table): + print(f"Showing the first {limit} row from {full_table}:") + spark.sql(f"SELECT * FROM {full_table} LIMIT {limit}").show(truncate=False) + else: + print(f"Table {full_table} not found. Skipping preview.") + + +# ---------------- Helpers ---------------- + +def parse_taxid_args(taxid_arg: Optional[str], taxid_file: Optional[str]) -> list[str]: + """ + Parse and collect valid numeric TaxIDs from command-line arguments and file. + """ + + ## empty list to collect taxids,avoid the duplicate TaxIDs + taxids: list[str] = [] + + # Parse taxid argument: --taxid "224325, 2741724" + if taxid_arg: + id_list = taxid_arg.split(",") ## separate them into a list using commas + for num in id_list: + # Keep only digits + id = re.sub(r"\D+", "", num.strip()) ## Remove all non-numeric characters + if id: + taxids.append(id) + + # Parse --taxid-file + if taxid_file: + if not os.path.exists(taxid_file): + raise click.BadParameter(f"Path '{taxid_file}' does not exist.", param_hint="--taxid-file") + with open(taxid_file, "r", encoding="utf-8") as f: + for line in f: + id = re.sub(r"\D+", "", line.strip()) + if id: + taxids.append(id) + + # Deduplicate while preserving order + seen = set() + unique_taxids = [] + for id in taxids: + if id not in seen: + seen.add(id) + unique_taxids.append(id) + + return unique_taxids + + +# ---------------- NCBI Datasets v2 ---------------- + +def fetch_reports_by_taxon( + taxon: str, + api_key: str | None = None, + page_size: int = 500, + refseq_only: bool = True, + current_only: bool = True, + debug: bool = False, +): + """ + Generator to iterate through genome dataset reports from NCBI Datasets v2 API by TaxID. + + Features: + - Calls the NCBI Datasets v2 endpoint for genome reports. + - Applies filters: RefSeq only / current assemblies only. + - Handles pagination via `next_page_token`. + - Yields report dicts for each assembly. 
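+    - Stops early and prints a message on request errors or empty result pages.
+
+    Usage sketch (illustrative; requires network access to the NCBI API):
+        for rep in fetch_reports_by_taxon("224325", page_size=100):
+            print(rep.get("accession"))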
+ + """ + + # ---------------- API endpoint ---------------- + # Base URL for NCBI Datasets REST API v2 + base = "https://api.ncbi.nlm.nih.gov/datasets/v2" + url = f"{base}/genome/taxon/{taxon}/dataset_report" + + # ---------------- Request params ---------------- + # metadata + assembly report text + params = { + "page_size": page_size, + "returned_content": "COMPLETE", + "filters.report_type": "assembly_report" + } + + if current_only: + params["filters.assembly_version"] = "current" + if refseq_only: + params["filters.assembly_source"] = "refseq" + + # ---------------- Headers ---------------- + headers = {"Accept": "application/json"} + if api_key: + headers["api-key"] = api_key + + # ---------------- Pagination loop ---------------- + token = None + while True: + if token: + params["page_token"] = token + + # ---- request ---- + try: + resp = requests.get(url, params=params, headers=headers, timeout=60) + resp.raise_for_status() + payload = resp.json() + except (requests.RequestException, ValueError) as e: + print(f"Request failed for taxon {taxon}: {e}") + break + + # ---- Extract reports ---- + reports = payload.get("reports", []) + if not reports: + print(f"No reports returned for taxon {taxon}") + break + + # ---------------- Filter loop ---------------- + for rep in reports: + info = rep.get("assemblyInfo") or rep.get("assembly_info") or {} + src_db = info.get("sourceDatabase") + + # Skip if explicitly marked as GenBank + if src_db and src_db != "SOURCE_DATABASE_REFSEQ": + continue + + # Print source info if debugging + if debug: + if src_db is None: + print(f"[DEBUG] accession={rep.get('accession')} has no sourceDatabase field") + else: + print(f"[DEBUG] accession={rep.get('accession')} sourceDatabase={src_db}") + + # Print the first 200 chars of assemblyReport for inspection + if info.get("assemblyReport"): + snippet = info["assemblyReport"][:200].replace("\n", " ") + print(f"[DEBUG] {snippet}") + + # Yield one assembly report to caller + yield rep + + # ---------------- Handle pagination ---------------- + token = payload.get("next_page_token") + if not token: + break + + + +# ---------------- Robust extractors set up ---------------- + +# regex patterns +PAT_BIOSAMPLE = re.compile(r"\bSAMN\d+\b") +PAT_BIOPROJECT = re.compile(r"\bPRJNA\d+\b") +PAT_GCF = re.compile(r"\bGCF_\d{9}\.\d+\b") +PAT_GCA = re.compile(r"\bGCA_\d{9}\.\d+\b") + + +def _coalesce(*vals: Any) -> str | None: + """ + Return the first non-empty, non-whitespace string from a list of inputs. + """ + for v in vals: + if isinstance(v, str): + trimmed = v.strip() + if trimmed: + return trimmed + return None + + +def _deep_find_str(obj: Any, target_keys: set[str]) -> str | None: + """ + Recursively search a nested dict/list structure for the first non-empty string value under any of the target_keys. 
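+    Matching is by exact key name and the search is depth-first, so the first non-empty
+    string found under any target key wins. Examples (the nested one is illustrative):
+        _deep_find_str({"info": {"assemblyName": "ASM1"}}, {"assemblyName"}) -> "ASM1"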
+ _deep_find_str({"assemblyDate": "2000-12-01"}, {"assemblyDate"}) -> "2000-12-01" + + """ + + ## If the object is a dictionary + if isinstance(obj, dict): + for k, v in obj.items(): + if isinstance(k, str) and k in target_keys and isinstance(v, str) and v.strip(): + return v.strip() + + ## recursively search the v + found = _deep_find_str(v, target_keys) + if found: + return found + + ## If the object is a list, search each element + elif isinstance(obj, list): + for it in obj: + found = _deep_find_str(it, target_keys) + if found: + return found + return None + + +def _deep_collect_regex(obj: Any, pattern: re.Pattern) -> list[str]: + """ + Recursively collect all unique regex matches from a nested structure + + Args: + obj: The nested object to search (can be dict, list, or string). + pattern: Compiled regex pattern to search for within string values. + + Returns: + A sorted list of unique regex matches found anywhere inside the object. + + """ + + results = set() # Use a set to avoid duplicate matches + + def _walk(x): + if isinstance(x, dict): + # Recursively process each value in the dictionary + for v in x.values(): + _walk(v) + elif isinstance(x, list): + # Recursively process each element in the list + for v in x: + _walk(v) + elif isinstance(x, str): + # Apply regex to the string and add matches to results + for m in pattern.findall(x): + results.add(m) + + # Start recursion + _walk(obj) + + # Convert set to sorted list for consistent ordering + return sorted(results) + + +# ---------------- Robust extractors ---------------- + +def extract_created_date(rep: dict[str, Any], allow_genbank_date: bool = False, debug: bool = False) -> str | None: + """ + + Extract creation/release date for a genome assembly. + + Priority: + - For RefSeq: release_date > assembly_date > submission_date + - For GenBank (if allowed): submission_date + Returns None if no valid date is found. + + """ + + # Normalize assembly info + assem_data = rep.get("assembly_info") or rep.get("assemblyInfo") or {} + src_db = rep.get("source_database") or assem_data.get("sourceDatabase") + + # Collect candidate dates + candidates: dict[str, str] = {} + for src in (assem_data, rep.get("assembly") or {}, rep): + for key in ["releaseDate", "assemblyDate", "submissionDate", + "release_date", "assembly_date", "submission_date"]: + v = src.get(key) # safely fetch the value (None if key not found) + + # Only accept non-empty string values + if isinstance(v, str) and v.strip(): + # Normalize the key: "release_date" -> "releasedate" + norm_key = key.lower().replace("_", "") + + # Store the cleaned value in candidates under the normalized key + candidates[norm_key] = v.strip() + + if debug and candidates: + print(f"[DEBUG] found candidates={candidates}, source={src_db}") + + # RefSeq: prioritize release > assembly > submission + if src_db == "SOURCE_DATABASE_REFSEQ": + for pref in ("releasedate", "assemblydate", "submissiondate"): + if pref in candidates: + return candidates[pref] + + # GenBank: fallback only submission date + if allow_genbank_date and src_db == "SOURCE_DATABASE_GENBANK": + if "submissiondate" in candidates: + return candidates["submissiondate"] + + return None + + + +def extract_assembly_name(rep: dict[str, Any]) -> str | None: + """ + Extract the assembly name from a genome report record. + This function normalizes extraction because the assembly name may appear in different parts of the JSON depending on API. 
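+    The common assemblyInfo / assembly / top-level locations are tried first, then a deep
+    recursive search. Example (illustrative):
+        extract_assembly_name({"assemblyInfo": {"assemblyName": "ASM1234v1"}}) -> "ASM1234v1"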
+ + """ + + assembly_info = rep.get("assemblyInfo") or {} + a = rep.get("assembly") or {} + + # Try to extract directly from the most common locations + v = _coalesce( + assembly_info.get("assemblyName"), + a.get("assemblyName"), + rep.get("assemblyName"), + ) + if v: + return v + + # Fallback: recursively search the nested structure + return _deep_find_str(rep, {"assemblyName", "assembly_name"}) + + + +def extract_organism_name(rep: dict[str, Any]) -> str | None: + assembly_info = rep.get("assemblyInfo") or {} + a = rep.get("assembly") or {} + org_top = rep.get("organism") or {} + + # Candidate locations where the organism name might appear + candidates = [ + org_top.get("organismName"), + org_top.get("scientificName"), + org_top.get("taxName"), + (assembly_info.get("organism") or {}).get("organismName") if isinstance(assembly_info.get("organism"), dict) else None, + (a.get("organism") or {}).get("organismName") if isinstance(a.get("organism"), dict) else None + ] + + # Return the first non-empty candidate string + for v in candidates: + if isinstance(v, str) and v.strip(): + return v.strip() + + # Fallback deep search across nested structure + return _deep_find_str(rep, + {"organismName", "scientificName", "sciName", "taxName", + "displayName", "organism_name"}) + + + +def extract_taxid(rep: dict[str, Any]) -> str | None: + """ + Extract the NCBI Taxonomy ID (taxid) from a genome assembly report. + + """ + + def is_numeric_id(value) -> bool: + """Check if a value is a valid numeric taxid.""" + return isinstance(value, (int, float)) or (isinstance(value, str) and value.strip().isdigit()) + + # --- Try top-level organism block --- + org_top = rep.get("organism") or {} + v = org_top.get("taxId") or org_top.get("taxid") or org_top.get("taxID") + if is_numeric_id(v): + return str(int(v)) + + # --- Recursive search --- + def _deep_find_taxid(x): + if isinstance(x, dict): + for k, v in x.items(): + # normalize key (lowercase) and check if it contains 'taxid' + if isinstance(k, str) and k.lower().replace("_", "") in {"taxid", "taxidvalue"}: + if is_numeric_id(v): + return str(int(v)) + found = _deep_find_taxid(v) + if found: + return found + elif isinstance(x, list): + for it in x: + found = _deep_find_taxid(it) + if found: + return found + return None + + return _deep_find_taxid(rep) + + + +def extract_biosample_ids(rep: dict[str, Any]) -> list[str]: + """ + Extract BioSample IDs from a genome assembly report. 
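+    The usual assemblyInfo / assembly / top-level 'biosample' blocks are checked first, with
+    a SAMN-accession regex over the whole record as a fallback. Example (illustrative):
+        extract_biosample_ids({"biosample": {"accession": "SAMN12345"}}) -> ["SAMN12345"]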
+ """ + accs = set() + + # --- Check standard paths --- + for path in [ + rep.get("assemblyInfo", {}).get("biosample"), + rep.get("assembly", {}).get("biosample"), + rep.get("biosample") + ]: + if isinstance(path, dict): + v = path.get("accession") or path.get("biosampleAccession") + if isinstance(v, str) and v.strip(): + accs.add(v.strip()) + elif isinstance(path, list): + for it in path: + if isinstance(it, dict): + v = it.get("accession") or it.get("biosampleAccession") + if isinstance(v, str) and v.strip(): + accs.add(v.strip()) + + # --- Regex fallback if no BioSamples found --- + if not accs: + accs.update(_deep_collect_regex(rep, PAT_BIOSAMPLE)) + + # --- Return unique + sorted IDs --- + return sorted(accs) + + + +def extract_bioproject_ids(rep: dict[str, Any]) -> list[str]: + accs = set() + for path in [ + rep.get("assemblyInfo", {}).get("bioproject"), + rep.get("assembly", {}).get("bioproject"), + rep.get("bioproject"), + ]: + if isinstance(path, dict): + v = path.get("accession") or path.get("bioprojectAccession") + if isinstance(v, str) and v.strip(): + accs.add(v.strip()) + elif isinstance(path, list): + for it in path: + if isinstance(it, dict): + v = it.get("accession") or it.get("bioprojectAccession") + if isinstance(v, str) and v.strip(): + accs.add(v.strip()) + if not accs: + accs.update(_deep_collect_regex(rep, PAT_BIOPROJECT)) + return sorted(accs) + + +def extract_assembly_accessions(rep: dict[str, Any]) -> tuple[list[str], list[str]]: + """ + Extract RefSeq (GCF_) and GenBank (GCA_) accession IDs from an assembly report. + - Consider only the top-level accession and assembly_info.paired_assembly. + - Classify accessions into GCF (RefSeq) and GCA (GenBank). + - Return sorted unique lists for both GCF and GCA. + + Returns: + A tuple of two lists: + - List of GCF accessions (RefSeq) + - List of GCA accessions (GenBank) + """ + + gcf, gca = set(), set() + + def _add_if_valid(acc: str): + """Helper: classify accession into GCF or GCA bucket.""" + if not isinstance(acc, str) or not acc.strip(): + return + acc = acc.strip() + if acc.startswith("GCF_"): + gcf.add(acc) + elif acc.startswith("GCA_"): + gca.add(acc) + + # --- Top-level accession --- + _add_if_valid(rep.get("accession")) + + # --- Paired assembly accession (from assembly_info) --- + ai = rep.get("assembly_info") or rep.get("assemblyInfo") or {} + paired = ai.get("paired_assembly", {}) + if isinstance(paired, dict): + _add_if_valid(paired.get("accession")) + + return sorted(gcf), sorted(gca) + + + +# ---------------- CDM Builders ---------------- +def build_cdm_datasource() -> pd.DataFrame: + """Generate a single CDM datasource record for NCBI RefSeq (via API).""" + record = { + "name": "RefSeq", + "source": "NCBI RefSeq", + "url": "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/taxon/", + "accessed": date.today().isoformat(), + "version": "231", + } + return pd.DataFrame([record]) + + +## entity +CDM_NAMESPACE = uuid.UUID("11111111-2222-3333-4444-555555555555") + +def build_entity_id(key: str) -> str: + """ + Generate a deterministic CDM ID from a given key using UUIDv5. + """ + k = (key or "").strip() + return f"CDM:{uuid.uuid5(CDM_NAMESPACE, k)}" + + +def build_cdm_entity( + key_for_uuid: str, + created_date: str | None, + *, + entity_type: Literal["contig_collection", "genome", "protein", "gene"] = "contig_collection", + data_source: str = "RefSeq") -> tuple[pd.DataFrame, str]: + + """ + Build a CDM 'entity' row with CDM-style entity_id, creation, and update timestamps. 
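+    The entity_id is derived deterministically from key_for_uuid via build_entity_id (UUIDv5),
+    so reloading the same accession yields the same ID. Example (illustrative):
+        df, entity_id = build_cdm_entity("GCF_000008665.1", "2000-12-01")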
+ + Parameters: + - key_for_uuid: a unique identifier string to derive the CDM UUID + - created_date: original creation date (from source), or None to use today + - entity_type: CDM entity type (default: 'contig_collection') + - data_source: Source name for traceability (default: 'RefSeq') + + Returns: + - A tuple of (DataFrame with 1 row, entity_id string) + """ + + entity_id = build_entity_id(key_for_uuid) + record = { + "entity_id": entity_id, + "entity_type": entity_type, + "data_source": data_source, + "created": created_date or date.today().isoformat(), + "updated": datetime.now().isoformat(timespec="seconds"), + } + return pd.DataFrame([record]), entity_id + + +## contig_collection +def build_cdm_contig_collection(entity_id: str, taxid: str | None = None, collection_type: str = "isolate") -> pd.DataFrame: + """ + Build the contig_collection CDM table with taxon ID and collection type. + """ + return pd.DataFrame([{ + "collection_id": str(entity_id), + "contig_collection_type": str(collection_type), + "ncbi_taxon_id": f"NCBITaxon:{taxid}" if taxid else None, + "gtdb_taxon_id": None, # reserved for future GTDB support + }]) + + +def build_cdm_name_rows(entity_id: str, rep: dict[str, Any]) -> pd.DataFrame: + """ + Build two name rows: organism name and assembly name. + - Organism name comes from extract_organism_name() + - Assembly name comes from extract_assembly_name() + """ + rows = [] + + # ==== organism name ==== + full_org_name = extract_organism_name(rep) + if isinstance(full_org_name, str) and full_org_name.strip(): + rows.append({ + "entity_id": str(entity_id), + "name": full_org_name.strip(), + "description": "RefSeq organism name", + "source": "RefSeq" + }) + + # ==== assembly name ==== + asm_name = extract_assembly_name(rep) + if isinstance(asm_name, str) and asm_name.strip(): + rows.append({ + "entity_id": str(entity_id), + "name": asm_name.strip(), + "description": "RefSeq assembly name", + "source": "RefSeq" + }) + + return pd.DataFrame(rows) + + +# ---------------- Identifier Constants ---------------- +IDENTIFIER_PREFIXES = { + "biosample": ("Biosample", "BioSample ID"), + "bioproject": ("BioProject", "BioProject ID"), + "taxon": ("NCBITaxon", "NCBI Taxon ID"), + "gcf": ("ncbi.assembly", "NCBI Assembly ID"), + "gca": ("insdc.gca", "GenBank Assembly ID"), + "gcf_unit": ("insdc.gca", "RefSeq Unit Assembly ID"), + "gca_unit": ("insdc.gca", "GenBank Unit Assembly ID"), +} + + + +def build_cdm_identifier_rows(entity_id: str, rep: dict, request_taxid: str | None) -> list[dict]: + """ + Build CDM 'identifier' rows from parsed NCBI RefSeq metadata. 
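+    Each row links the CDM entity_id to one prefixed external ID (BioSample, BioProject,
+    NCBI Taxon, or a GCF/GCA assembly accession). Example (illustrative):
+        build_cdm_identifier_rows("CDM:1234", {"organism": {"taxId": "789"}}, None)
+        -> [{"entity_id": "CDM:1234", "identifier": "NCBITaxon:789",
+             "source": "RefSeq", "description": "NCBI Taxon ID"}]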
+ """ + + rows = [] + + # ---- BioSample IDs ---- + for bs in extract_biosample_ids(rep): + rows.append({ + "entity_id": entity_id, + "identifier": f"{IDENTIFIER_PREFIXES['biosample'][0]}:{bs.strip()}", + "source": "RefSeq", + "description": IDENTIFIER_PREFIXES['biosample'][1], + }) + + # ---- BioProject IDs ---- + for bp in extract_bioproject_ids(rep): + rows.append({ + "entity_id": entity_id, + "identifier": f"{IDENTIFIER_PREFIXES['bioproject'][0]}:{bp.strip()}", + "source": "RefSeq", + "description": IDENTIFIER_PREFIXES['bioproject'][1], + }) + + # ---- Taxon ID ---- + tx = extract_taxid(rep) or (str(request_taxid).strip() if request_taxid else None) + if tx and tx.isdigit(): + rows.append({ + "entity_id": entity_id, + "identifier": f"{IDENTIFIER_PREFIXES['taxon'][0]}:{tx}", + "source": "RefSeq", + "description": IDENTIFIER_PREFIXES['taxon'][1], + }) + + # ---- Assembly Accessions (GCF / GCA) ---- + gcf_list, gca_list = extract_assembly_accessions(rep) + for gcf in gcf_list: + rows.append({ + "entity_id": entity_id, + "identifier": f"{IDENTIFIER_PREFIXES['gcf'][0]}:{gcf.strip()}", + "source": "RefSeq", + "description": IDENTIFIER_PREFIXES['gcf'][1], + }) + for gca in gca_list: + rows.append({ + "entity_id": entity_id, + "identifier": f"{IDENTIFIER_PREFIXES['gca'][0]}:{gca.strip()}", + "source": "RefSeq", + "description": IDENTIFIER_PREFIXES['gca'][1], + }) + + # ---- Deduplicate ---- + seen = set() + uniq = [] + for r in rows: + ident = r["identifier"] + + k = (r["entity_id"], ident) + if k not in seen: + seen.add(k) + uniq.append(r) + + return uniq + + +# ---------------- CLI ---------------- +@click.command() +@click.option("--taxid", required=True, + help="Comma-separated NCBI TaxIDs, e.g. '224325,2741724'.") +@click.option("--api-key", default=None, + help="Optional NCBI API key (increases rate limits).") +@click.option("--database", default="refseq_api", show_default=True, + help="Delta schema/database.") +@click.option("--mode", default="overwrite", type=click.Choice(["overwrite", "append"]), show_default=True, + help="Write mode for Delta tables.") +@click.option("--debug/--no-debug", default=False, show_default=True, + help="Print per-record parsed fields for troubleshooting.") +@click.option("--allow-genbank-date/--no-allow-genbank-date", default=False, show_default=True, + help="Allow using GenBank submissionDate as fallback for RefSeq created date.") +@click.option("--unique-per-taxon/--all-assemblies", default=False, show_default=True, + help="Keep only one assembly per taxon (latest by release_date).") + + +def cli(taxid, api_key, database, mode, debug, allow_genbank_date, unique_per_taxon): + main( + taxid=taxid, + api_key=api_key, + database=database, + mode=mode, + debug=debug, + allow_genbank_date=allow_genbank_date, + unique_per_taxon=unique_per_taxon + ) + + +def process_report(rep: dict, tx: str, seen: set, debug: bool, allow_genbank_date: bool): + """ + Process a single assembly report and return partial CDM tables (entity, collection, names, identifiers). + Skip duplicates using 'seen'. 
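+    Usage sketch (illustrative; 'rep' is one report dict yielded by fetch_reports_by_taxon):
+        e, c, n, i = process_report(rep, "224325", set(), debug=False, allow_genbank_date=False)
+        # e and c hold one-row DataFrames; n and i hold plain dict rows.
+        # A key already present in 'seen' yields four empty lists.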
+ """ + entities, collections, names, identifiers = [], [], [], [] + + # === accession (prefer GCF, fallback to GCA) === + gcf_list, gca_list = extract_assembly_accessions(rep) + acc = gcf_list[0] if gcf_list else (gca_list[0] if gca_list else None) + + # === assembly + organism names === + asm_name = extract_assembly_name(rep) + org_name = extract_organism_name(rep) + + # === creation date === + created = extract_created_date(rep, allow_genbank_date=allow_genbank_date) + if not created: + if debug: + print(f"[WARN] No RefSeq date for accession={acc}") + created = date.today().isoformat() + + # === unique key for entity === + key = _coalesce(acc, asm_name, org_name, tx) + if not key or key in seen: + return entities, collections, names, identifiers + seen.add(key) + + # ---------------- CDM tables ---------------- + df_entity, entity_id = build_cdm_entity(key, created) + entities.append(df_entity) + + collections.append(build_cdm_contig_collection(entity_id, taxid=tx)) + + rows_name = build_cdm_name_rows(entity_id, rep) + if not rows_name.empty: + names.extend(rows_name.to_dict(orient="records")) + + rows_id = build_cdm_identifier_rows(entity_id, rep, tx) + identifiers.extend(rows_id) + + return entities, collections, names, identifiers + + +def process_taxon(tx: str, api_key: str, debug: bool, allow_genbank_date: bool, unique_per_taxon: bool, seen: set): + """ + Process all reports for a given TaxID and return combined partial tables. + """ + entities, collections, names, identifiers = [], [], [], [] + + reports = list(fetch_reports_by_taxon(taxon=tx, api_key=api_key)) + + # If only unique assembly per taxon → keep latest + if unique_per_taxon and reports: + reports.sort(key=lambda r: extract_created_date(r, allow_genbank_date) or "0000-00-00", reverse=True) + reports = [reports[0]] + + for rep in reports: + e, c, n, i = process_report(rep, tx, seen, debug, allow_genbank_date) + entities.extend(e) + collections.extend(c) + names.extend(n) + identifiers.extend(i) + + return entities, collections, names, identifiers + + + +def finalize_tables(entities, collections, names, identifiers): + """ + Concatenate and deduplicate CDM tables. + """ + def _concat(frames): return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame() + def _dedup(pdf, cols): return pdf.drop_duplicates(subset=cols) if not pdf.empty else pdf + + pdf_entity = _dedup(_concat(entities), ["entity_id"]) + pdf_coll = _dedup(_concat(collections), ["collection_id"]) + pdf_name = _dedup(pd.DataFrame(names), ["entity_id", "name"]) if names else pd.DataFrame() + pdf_ident = _dedup(pd.DataFrame(identifiers), ["entity_id", "identifier"]) if identifiers else pd.DataFrame() + + return pdf_entity, pdf_coll, pdf_name, pdf_ident + + + +def write_and_preview(spark, database, mode, pdf_entity, pdf_coll, pdf_name, pdf_ident): + """ + Write tables to Delta and preview results. 
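+    Writes the entity, contig_collection, name, and identifier tables, then previews every CDM
+    table, including the datasource table written earlier in main(). Example (illustrative):
+        write_and_preview(spark, "refseq_api", "overwrite", pdf_entity, pdf_coll, pdf_name, pdf_ident)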
+ """ + write_delta(spark, pdf_entity, database, "entity", mode) + write_delta(spark, pdf_coll, database, "contig_collection", mode) + write_delta(spark, pdf_name, database, "name", mode) + write_delta(spark, pdf_ident, database, "identifier", mode) + + print("\nDelta tables written:") + for tbl in ["datasource", "entity", "contig_collection", "name", "identifier"]: + preview_or_skip(spark, database, tbl) + + + +def main(taxid, api_key, database, mode, debug, allow_genbank_date=False, unique_per_taxon=False): + spark = build_spark(database) + + # datasource table (fixed record) + df_ds = build_cdm_datasource() + write_delta(spark, df_ds, database, "datasource", mode) + + entities, collections, names, identifiers = [], [], [], [] + seen = set() + + taxids = [t.strip() for t in taxid.split(",") if t.strip()] + print(f"Using TaxIDs: {taxids}") + + for tx in taxids: + print(f"Fetching taxon={tx}") + e, c, n, i = process_taxon(tx, api_key, debug, allow_genbank_date, unique_per_taxon, seen) + + # extend results into global containers + entities.extend(e) + collections.extend(c) + names.extend(n) + identifiers.extend(i) + + pdf_entity, pdf_coll, pdf_name, pdf_ident = finalize_tables(entities, collections, names, identifiers) + write_and_preview(spark, database, mode, pdf_entity, pdf_coll, pdf_name, pdf_ident) + + +if __name__ == "__main__": + cli() + + + + + From 13e7df376a279d2d631079eb6690ec41ee6752a7 Mon Sep 17 00:00:00 2001 From: alinakbase Date: Wed, 3 Sep 2025 11:48:35 -0700 Subject: [PATCH 2/2] Add files via upload --- tests/test_refseq_api.py | 721 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 721 insertions(+) create mode 100644 tests/test_refseq_api.py diff --git a/tests/test_refseq_api.py b/tests/test_refseq_api.py new file mode 100644 index 0000000..a0c32de --- /dev/null +++ b/tests/test_refseq_api.py @@ -0,0 +1,721 @@ +import uuid +import click +import pytest +import requests +import pandas as pd +from datetime import date, datetime +from unittest.mock import patch, MagicMock +from refseq_api import fetch_reports_by_taxon +from refseq_api import _coalesce, _deep_find_str, _deep_collect_regex, PAT_BIOSAMPLE, PAT_BIOPROJECT, PAT_GCF, PAT_GCA +from refseq_api import extract_created_date +from refseq_api import extract_assembly_name +from refseq_api import extract_organism_name +from refseq_api import extract_taxid +from refseq_api import extract_biosample_ids +from refseq_api import extract_bioproject_ids +from refseq_api import extract_assembly_accessions +from refseq_api import build_cdm_datasource +from refseq_api import build_entity_id, CDM_NAMESPACE +from refseq_api import build_cdm_entity +from refseq_api import build_cdm_contig_collection +from refseq_api import build_cdm_name_rows +from refseq_api import build_cdm_identifier_rows +from refseq_api import process_report, process_taxon, finalize_tables, write_and_preview, main +from refseq_api import parse_taxid_args + + +@pytest.mark.parametrize( + "taxid_arg, taxid_file_content, expected", + [ + ("224325", None, ["224325"]), + ("224325,2741724", None, ["224325", "2741724"]), + ("TaxID:224325, abc2741724", None, ["224325", "2741724"]), + ("224325,224325,2741724", None, ["224325", "2741724"]), + ("", None, [])] +) + + +def test_parse_taxid_args_inline(taxid_arg, taxid_file_content, expected, tmp_path): + taxid_file = None + if taxid_file_content: + taxid_file = tmp_path / "taxids.txt" + taxid_file.write_text(taxid_file_content) + assert parse_taxid_args(taxid_arg, taxid_file) == expected + + +def 
test_parse_taxid_file_not_found(): + with pytest.raises(click.BadParameter): + parse_taxid_args(None, "nonexistent.txt") + + +def make_response(reports, next_token=None): + payload = {"reports": reports} + if next_token: + payload["next_page_token"] = next_token + mock_resp = MagicMock() + mock_resp.json.return_value = payload + mock_resp.raise_for_status.return_value = None + return mock_resp + + +@pytest.mark.parametrize( + "mock_reports, side_effect, expected_accessions", + [ + # RefSeq only: should keep only RefSeq record + ( + [ + {"accession": "GCA_123", "assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_GENBANK"}}, + {"accession": "GCF_456", "assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_REFSEQ"}}, + ], + None, + ["GCF_456"] + ), + # Empty reports + ([], None, []), + # Pagination + ( + [ + {"accession": "GCF_1", "assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_REFSEQ"}} + ], + None, # side_effect controlled separately + ["GCF_1", "GCF_2"] + ), + ] +) +@patch("refseq_api.requests.get") +def test_fetch_reports(mock_get, mock_reports, side_effect, expected_accessions): + if expected_accessions == ["GCF_1", "GCF_2"]: + first_page = make_response(mock_reports, next_token="token123") + second_page = make_response( + [{"accession": "GCF_2", "assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_REFSEQ"}}] + ) + mock_get.side_effect = [first_page, second_page] + elif side_effect: + mock_get.side_effect = side_effect + else: + mock_get.return_value = make_response(mock_reports) + + results = list(fetch_reports_by_taxon("1234")) + accessions = [r["accession"] for r in results] + assert accessions == expected_accessions + + +# Network error +@patch("refseq_api.requests.get") +def test_network_error(mock_get): + mock_get.side_effect = requests.RequestException("Network down") + results = list(fetch_reports_by_taxon("1234")) + assert results == [] + + +# -------------------- _coalesce -------------------- +@pytest.mark.parametrize("inputs, expected", [ + (["", " ", None, "abc"], "abc"), # should pick first non-empty + ([None, " ", "xyz", "zzz"], "xyz"), + ([None, "", " "], None), # all empty +]) +def test_coalesce(inputs, expected): + assert _coalesce(*inputs) == expected + + +# -------------------- _deep_find_str -------------------- +@pytest.mark.parametrize("obj, keys, expected", [ + ({"assemblyDate": "2000-12-01"}, {"assemblyDate"}, "2000-12-01"), + ({"nested": {"releaseDate": "2010-01-01"}}, {"releaseDate"}, "2010-01-01"), + ({"list": [{"submissionDate": "2020-05-05"}]}, {"submissionDate"}, "2020-05-05"), + ({"noDate": "xxx"}, {"releaseDate"}, None), +]) +def test_deep_find_str(obj, keys, expected): + assert _deep_find_str(obj, keys) == expected + + +# -------------------- _deep_collect_regex -------------------- +@pytest.mark.parametrize("obj, pattern, expected", [ + ("Biosample SAMN12345 here", PAT_BIOSAMPLE, ["SAMN12345"]), + ({"a": "Project PRJNA99999"}, PAT_BIOPROJECT, ["PRJNA99999"]), + (["Genome GCF_000123456.1"], PAT_GCF, ["GCF_000123456.1"]), + ({"list": ["Some GCA_000987654.2", "Other GCA_000987654.2"]}, PAT_GCA, ["GCA_000987654.2"]), # dedup + ("No matches here", PAT_BIOSAMPLE, []), +]) +def test_deep_collect_regex(obj, pattern, expected): + assert _deep_collect_regex(obj, pattern) == expected + + +@pytest.mark.parametrize( + "rep, allow_genbank, expected", + [ + # --- RefSeq record: release_date present --- + ( + {"assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_REFSEQ", "releaseDate": "2020-01-01"}}, + False, + "2020-01-01", + ), + # --- RefSeq record: no release_date, 
use assembly_date --- + ( + {"assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_REFSEQ", "assemblyDate": "2021-05-05"}}, + False, + "2021-05-05", + ), + # --- RefSeq record: no release/assembly, use submission_date --- + ( + {"assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_REFSEQ", "submissionDate": "2022-03-03"}}, + False, + "2022-03-03", + ), + # --- GenBank record: allow_genbank_date = False, expect None --- + ( + {"assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_GENBANK", "submissionDate": "2019-09-09"}}, + False, + None, + ), + # --- GenBank record: allow_genbank_date = True, expect submission_date --- + ( + {"assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_GENBANK", "submissionDate": "2019-09-09"}}, + True, + "2019-09-09", + ), + # --- No dates at all --- + ( + {"assemblyInfo": {"sourceDatabase": "SOURCE_DATABASE_REFSEQ"}}, + False, + None, + ), + ] +) + +def test_extract_created_date(rep, allow_genbank, expected): + result = extract_created_date(rep, allow_genbank_date=allow_genbank, debug=True) + assert result == expected + + +@pytest.mark.parametrize( + "rep, expected", + [ + # Top-level assemblyName + ({"assemblyName": "ASM1234v1"}, "ASM1234v1"), + # Inside assemblyInfo + ({"assemblyInfo": {"assemblyName": "ASM5678v1"}}, "ASM5678v1"), + # Inside assembly + ({"assembly": {"assemblyName": "ASM9999v1"}}, "ASM9999v1"), + # Deep nested (fallback to _deep_find_str) + ({"meta": {"nested": {"assembly_name": "ASM_DEEPv1"}}}, "ASM_DEEPv1"), + # No name available + ({"assemblyInfo": {}, "assembly": {}}, None), + ] +) + +def test_extract_assembly_name(rep, expected): + assert extract_assembly_name(rep) == expected + + +@pytest.mark.parametrize( + "rep, expected", + [ + # Top-level organismName + ({"organism": {"organismName": "Haloferax volcanii"}}, "Haloferax volcanii"), + # Top-level scientificName + ({"organism": {"scientificName": "Methanococcus maripaludis"}}, "Methanococcus maripaludis"), + # Top-level taxName + ({"organism": {"taxName": "Archaeoglobus fulgidus"}}, "Archaeoglobus fulgidus"), + # assemblyInfo.organism.organismName + ({"assemblyInfo": {"organism": {"organismName": "Sulfolobus islandicus"}}}, "Sulfolobus islandicus"), + # assembly.organism.organismName + ({"assembly": {"organism": {"organismName": "Thermococcus kodakarensis"}}}, "Thermococcus kodakarensis"), + # Deep nested key (should trigger _deep_find_str) + ({"meta": {"data": {"deep": {"scientificName": "Pyrococcus furiosus"}}}}, "Pyrococcus furiosus"), + # No name available + ({"assemblyInfo": {}, "assembly": {}, "organism": {}}, None), + ] +) + +def test_extract_organism_name(rep, expected): + assert extract_organism_name(rep) == expected + + +@pytest.mark.parametrize( + "rep, expected", + [ + # Top-level taxId as int + ({"organism": {"taxId": 12345}}, "12345"), + + # Top-level taxid as string + ({"organism": {"taxid": "67890"}}, "67890"), + + # Top-level taxID as float + ({"organism": {"taxID": 11111.0}}, "11111"), + + # Nested dict with taxid + ({"meta": {"organism": {"taxid": "22222"}}}, "22222"), + + # Nested deeply inside list + ({"organisms": [{"nested": {"tax_id": "33333"}}]}, "33333"), + + # Invalid taxid (non-digit string) + ({"organism": {"taxId": "abc"}}, None), + + # Missing entirely + ({"organism": {}}, None), + ] +) + +def test_extract_taxid(rep, expected): + assert extract_taxid(rep) == expected + + +@pytest.mark.parametrize( + "rep, expected", + [ + # biosample in assemblyInfo dict + ( + {"assemblyInfo": {"biosample": {"accession": "SAMN12345"}}}, + ["SAMN12345"] + ), + + # biosample in 
assembly dict + ( + {"assembly": {"biosample": {"biosampleAccession": "SAMN67890"}}}, + ["SAMN67890"] + ), + + # biosample at top level + ( + {"biosample": {"accession": "SAMN11111"}}, + ["SAMN11111"] + ), + + # biosample as list of dicts + ( + {"biosample": [{"accession": "SAMN22222"}, {"biosampleAccession": "SAMN33333"}]}, + ["SAMN22222", "SAMN33333"] + ), + + # fallback to regex + ( + {"note": "Sample accession SAMN44444 found in text"}, + ["SAMN44444"] + ), + + # no biosample at all + ( + {"assemblyInfo": {}, "assembly": {}}, + [] + ), + ] +) + +def test_extract_biosample_ids(rep, expected): + result = extract_biosample_ids(rep) + assert result == expected + + +@pytest.mark.parametrize( + "rep, expected", + [ + # bioproject in assemblyInfo dict + ( + {"assemblyInfo": {"bioproject": {"accession": "PRJNA12345"}}}, + ["PRJNA12345"] + ), + + # bioproject in assembly dict + ( + {"assembly": {"bioproject": {"bioprojectAccession": "PRJNA67890"}}}, + ["PRJNA67890"] + ), + + # bioproject at top level + ( + {"bioproject": {"accession": "PRJNA11111"}}, + ["PRJNA11111"] + ), + + # bioproject as list of dicts + ( + {"bioproject": [{"accession": "PRJNA22222"}, {"bioprojectAccession": "PRJNA33333"}]}, + ["PRJNA22222", "PRJNA33333"] + ), + + # fallback to regex + ( + {"note": "Bioproject accession PRJNA44444 found in free text"}, + ["PRJNA44444"] + ), + + # no bioproject at all + ( + {"assemblyInfo": {}, "assembly": {}}, + [] + ), + ] +) + +def test_extract_bioproject_ids(rep, expected): + result = extract_bioproject_ids(rep) + assert result == expected + + +@pytest.mark.parametrize( + "rep, expected_gcf, expected_gca", + [ + # Top-level accession is GCF + ({"accession": "GCF_000123456.1"}, ["GCF_000123456.1"], []), + + # Top-level accession is GCA + ({"accession": "GCA_000654321.1"}, [], ["GCA_000654321.1"]), + + # Paired assembly is GCF + ( + {"assemblyInfo": {"paired_assembly": {"accession": "GCF_999999999.1"}}}, + ["GCF_999999999.1"], + [] + ), + + # Paired assembly is GCA + ( + {"assembly_info": {"paired_assembly": {"accession": "GCA_888888888.1"}}}, + [], + ["GCA_888888888.1"] + ), + + # Both top-level and paired (mixed GCF + GCA) + ( + { + "accession": "GCF_000123456.1", + "assemblyInfo": {"paired_assembly": {"accession": "GCA_000654321.1"}}, + }, + ["GCF_000123456.1"], + ["GCA_000654321.1"], + ), + + # Invalid accession (not GCF/GCA) → ignore + ({"accession": "XYZ_123"}, [], []), + + # Empty dict + ({}, [], []), + ] +) + +def test_extract_assembly_accessions(rep, expected_gcf, expected_gca): + gcf, gca = extract_assembly_accessions(rep) + assert gcf == expected_gcf + assert gca == expected_gca + + +@pytest.mark.parametrize( + "expected_name, expected_source, expected_url, expected_version", + [ + ( + "RefSeq", + "NCBI RefSeq", + "https://api.ncbi.nlm.nih.gov/datasets/v2/genome/taxon/", + "231", + ) + ] +) + +def test_build_cdm_datasource(expected_name, expected_source, expected_url, expected_version): + df = build_cdm_datasource() + + # Should return exactly one row + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + + record = df.iloc[0].to_dict() + + # Fixed fields + assert record["name"] == expected_name + assert record["source"] == expected_source + assert record["url"] == expected_url + assert record["version"] == expected_version + + # Accessed date should equal today's date + assert record["accessed"] == date.today().isoformat() + + +@pytest.mark.parametrize( + "key, expected_uuid", + [ + # Normal case: deterministic UUID + ( + "GCF_000008665.1", + 
f"CDM:{uuid.uuid5(CDM_NAMESPACE, 'GCF_000008665.1')}", + ), + + # Leading/trailing whitespace should not affect result + ( + " GCF_000008665.1 ", + f"CDM:{uuid.uuid5(CDM_NAMESPACE, 'GCF_000008665.1')}", + ), + + # Empty string should still return a valid UUID + ( + "", + f"CDM:{uuid.uuid5(CDM_NAMESPACE, '')}", + ), + + # Another different key must yield a different UUID + ( + "ASM12345v1", + f"CDM:{uuid.uuid5(CDM_NAMESPACE, 'ASM12345v1')}", + ), + ] +) + +def test_build_entity_id(key, expected_uuid): + result = build_entity_id(key) + assert result == expected_uuid + assert result.startswith("CDM:") + assert len(result) > 10 + + +@pytest.mark.parametrize( + "key, created_date, entity_type, data_source", + [ + ("GCF_000008665.1", "2000-12-01", "contig_collection", "RefSeq"), + ("ASM12345v1", None, "genome", "CustomSource"), + ], +) +def test_build_cdm_entity(key, created_date, entity_type, data_source): + # Run function + df, entity_id = build_cdm_entity( + key_for_uuid=key, + created_date=created_date, + entity_type=entity_type, + data_source=data_source) + + # --- General checks --- + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 # exactly one row + assert entity_id == build_entity_id(key) # UUID must match + + # --- Row content checks --- + row = df.iloc[0] + assert row["entity_id"] == entity_id + assert row["entity_type"] == entity_type + assert row["data_source"] == data_source + + # If created_date was provided, must match; else today + expected_created = created_date or date.today().isoformat() + assert row["created"] == expected_created + + # updated must be a valid ISO8601 timestamp (seconds precision) + try: + datetime.fromisoformat(row["updated"]) + except ValueError: + pytest.fail(f"Invalid updated timestamp: {row['updated']}") + + +@pytest.mark.parametrize( + "entity_id, taxid, collection_type, expected", + [ + ( + "CDM:1234", "224325", "isolate", + {"collection_id": "CDM:1234", "contig_collection_type": "isolate", "ncbi_taxon_id": "NCBITaxon:224325", "gtdb_taxon_id": None} + ), + ( + "CDM:5678", None, "metagenome", + {"collection_id": "CDM:5678", "contig_collection_type": "metagenome", "ncbi_taxon_id": None, "gtdb_taxon_id": None} + ), + ] +) + +def test_build_cdm_contig_collection(entity_id, taxid, collection_type, expected): + df = build_cdm_contig_collection(entity_id, taxid, collection_type) + assert len(df) == 1 + row = df.iloc[0].to_dict() + assert row == expected + + +@pytest.mark.parametrize( + "entity_id, rep, expected_rows", + [ + ( + "CDM:1234", + {"organism": {"organismName": "Haloarcula marismortui"}, + "assemblyInfo": {"assemblyName": "ASM1234v1"}}, + [ + { + "entity_id": "CDM:1234", + "name": "Haloarcula marismortui", + "description": "RefSeq organism name", + "source": "RefSeq" + }, + { + "entity_id": "CDM:1234", + "name": "ASM1234v1", + "description": "RefSeq assembly name", + "source": "RefSeq" + } + ] + ), + ( + "CDM:5678", + {"organism": {"scientificName": "Methanocaldococcus jannaschii"}}, + [ + { + "entity_id": "CDM:5678", + "name": "Methanocaldococcus jannaschii", + "description": "RefSeq organism name", + "source": "RefSeq" + } + ] + ), + ( + "CDM:9012", + {"assemblyInfo": {"assemblyName": "ASM9012v1"}}, + [ + { + "entity_id": "CDM:9012", + "name": "ASM9012v1", + "description": "RefSeq assembly name", + "source": "RefSeq" + } + ] + ), + ( + "CDM:0000", + {}, + [] + ), + ] +) + +def test_build_cdm_name_rows(entity_id, rep, expected_rows): + df = build_cdm_name_rows(entity_id, rep) + result = df.to_dict(orient="records") + assert result == 
expected_rows + + +ENTITY_ID = "CDM:12345" + +@pytest.mark.parametrize( + "rep, request_taxid, expected_identifiers", + [ + # ---- BioSample ---- + ( + {"assemblyInfo": {"biosample": {"accession": "SAMN123"}}}, + None, + ["Biosample:SAMN123"], + ), + + # ---- BioProject ---- + ( + {"assemblyInfo": {"bioproject": {"accession": "PRJNA456"}}}, + None, + ["BioProject:PRJNA456"], + ), + + # ---- Taxon from rep ---- + ( + {"organism": {"taxId": "789"}}, + None, + ["NCBITaxon:789"], + ), + + # ---- Taxon from request_taxid fallback ---- + ( + {}, + "2468", + ["NCBITaxon:2468"], + ), + + # ---- Assembly Accessions (GCF + GCA) ---- + ( + {"accession": "GCF_000001.1"}, + None, + ["ncbi.assembly:GCF_000001.1"], + ), + ( + {"accession": "GCA_000002.1"}, + None, + ["insdc.gca:GCA_000002.1"], + ), + + # ---- Deduplication (same identifier twice) ---- + ( + {"assemblyInfo": {"biosample": {"accession": "SAMN999"}}, + "assembly": {"biosample": {"biosampleAccession": "SAMN999"}}}, + None, + ["Biosample:SAMN999"], + ), + ] +) + +def test_build_cdm_identifier_rows(rep, request_taxid, expected_identifiers): + rows = build_cdm_identifier_rows(ENTITY_ID, rep, request_taxid) + identifiers = [r["identifier"] for r in rows] + assert identifiers == expected_identifiers + for r in rows: + assert r["entity_id"] == ENTITY_ID + assert r["source"] == "RefSeq" + + +# ---- dummy data ---- +dummy_rep = { + "accession": "GCF_000001", + "assemblyInfo": {"assemblyName": "ASM1", "sourceDatabase": "SOURCE_DATABASE_REFSEQ"}, + "organism": {"organismName": "Testus organism", "taxId": 1234}, + "biosample": {"accession": "SAMN0001"}, + "bioproject": {"accession": "PRJNA0001"} + } + + +def test_process_report_new_entity(): + seen = set() + e, c, n, i = process_report(dummy_rep, "1234", seen, debug=False, allow_genbank_date=False) + # entity dataframe should have 1 row + assert len(e) == 1 + assert not e[0].empty + # collections should include taxid + assert any("NCBITaxon:1234" in str(row) for row in c[0].to_dict("records")) + # names contain organism and assembly + assert any("Testus organism" in row["name"] for row in n) + # identifiers contain accession + assert any("SAMN0001" in row["identifier"] for row in i) + + +def test_process_report_duplicate(): + seen = {"GCF_000001"} # pre-fill key + e, c, n, i = process_report(dummy_rep, "1234", seen, debug=False, allow_genbank_date=False) + # should skip + assert e == [] and c == [] and n == [] and i == [] + + +@patch("refseq_api.fetch_reports_by_taxon") +def test_process_taxon_unique(mock_fetch): + rep1 = dict(dummy_rep) + rep2 = dict(dummy_rep); rep2["accession"] = "GCF_000002" + mock_fetch.return_value = [rep1, rep2] + + seen = set() + e, c, n, i = process_taxon("1234", api_key=None, debug=False, allow_genbank_date=False, unique_per_taxon=True, seen=seen) + assert len(e) == 1 # only one kept + + +def test_finalize_tables_dedup(): + df1 = pd.DataFrame([{"entity_id": "1"}]) + df2 = pd.DataFrame([{"entity_id": "1"}]) # duplicate + pdf_entity, _, _, _ = finalize_tables([df1, df2], [], [], []) + assert len(pdf_entity) == 1 + + +@patch("refseq_api.write_delta") +@patch("refseq_api.preview_or_skip") + +def test_write_and_preview(mock_preview, mock_write): + spark = MagicMock() + pdf = pd.DataFrame([{"entity_id": "1"}]) + write_and_preview(spark, "db", "overwrite", pdf, pdf, pdf, pdf) + assert mock_write.call_count == 4 + assert mock_preview.call_count == 5 + + +@patch("refseq_api.build_spark") +@patch("refseq_api.write_delta") +@patch("refseq_api.fetch_reports_by_taxon") + +def 
test_main_orchestration(mock_fetch, mock_write, mock_build):
+    # Orchestration smoke test: Spark, the Delta writes, and the NCBI fetch are all
+    # mocked, so main() should run end-to-end and issue at least one write_delta call.
+    mock_build.return_value = MagicMock()
+    mock_fetch.return_value = [dummy_rep]
+    main("1234", None, "db", "overwrite", debug=False)
+    assert mock_write.call_count >= 1