fix(ingest): decrease ingest memory usage for large datasets #3505

Open
wants to merge 14 commits into main
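The pattern repeated across the scripts below: metadata is no longer materialised as one large JSON object keyed by fasta ID, but written and read as NDJSON, one {"id": ..., "metadata": ...} record per line, streamed with orjsonl. A minimal sketch of the before/after idea (the path and values here are illustrative, not taken from the PR):

```python
import orjsonl  # NDJSON (JSON Lines) helper used throughout this PR

path = "example_metadata.ndjson"  # illustrative path

# Writer side: append one {"id": ..., "metadata": ...} record per line,
# instead of building a dict and json.dump-ing it in one go at the end.
orjsonl.append(path, {"id": "KJ682796.1", "metadata": {"segment": "L"}})
orjsonl.append(path, {"id": "KJ682809.1", "metadata": {"segment": "M"}})

# Reader side: stream records one at a time, instead of
# json.load(...) followed by iterating over .items() of the whole dict.
for field in orjsonl.stream(path):
    fasta_id = field["id"]
    record = field["metadata"]
    print(fasta_id, record)
```

Readers that can process records independently (compare_hashes.py, prepare_files.py) then only hold one record in memory at a time, which is what lowers peak memory for large datasets.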
18 changes: 9 additions & 9 deletions ingest/Snakefile
@@ -310,7 +310,7 @@ rule prepare_metadata:
sequence_hashes="results/sequence_hashes.ndjson",
config="results/config.yaml",
output:
metadata="results/metadata_post_prepare.json",
metadata="results/metadata_post_prepare.ndjson",
params:
log_level=LOG_LEVEL,
shell:
@@ -328,11 +328,11 @@ rule prepare_metadata:
rule group_segments:
input:
script="scripts/group_segments.py",
metadata="results/metadata_post_prepare.json",
metadata="results/metadata_post_prepare.ndjson",
sequences="results/sequences.ndjson",
config="results/config.yaml",
output:
metadata="results/metadata_post_group.json",
metadata="results/metadata_post_group.ndjson",
sequences="results/sequences_post_group.ndjson",
params:
log_level=LOG_LEVEL,
@@ -368,9 +368,9 @@ rule get_previous_submissions:
# By delaying the start of the script
script="scripts/call_loculus.py",
prepped_metadata=(
"results/metadata_post_group.json"
"results/metadata_post_group.ndjson"
if SEGMENTED
else "results/metadata_post_prepare.json"
else "results/metadata_post_prepare.ndjson"
),
config="results/config.yaml",
output:
@@ -395,9 +395,9 @@ rule compare_hashes:
config="results/config.yaml",
old_hashes="results/previous_submissions.json",
metadata=(
"results/metadata_post_group.json"
"results/metadata_post_group.ndjson"
if SEGMENTED
else "results/metadata_post_prepare.json"
else "results/metadata_post_prepare.ndjson"
),
output:
to_submit="results/to_submit.json",
@@ -431,9 +431,9 @@ rule prepare_files:
script="scripts/prepare_files.py",
config="results/config.yaml",
metadata=(
"results/metadata_post_group.json"
"results/metadata_post_group.ndjson"
if SEGMENTED
else "results/metadata_post_prepare.json"
else "results/metadata_post_prepare.ndjson"
),
sequences=(
"results/sequences_post_group.ndjson"
2 changes: 0 additions & 2 deletions ingest/scripts/call_loculus.py
@@ -411,8 +411,6 @@ def get_submitted(config: Config):
logger.info(f"Backend has status of: {len(statuses)} sequence entries from ingest")
logger.info(f"Ingest has submitted: {len(entries)} sequence entries to ingest")

logger.debug(entries)
logger.debug(statuses)
for entry in entries:
loculus_accession = entry["accession"]
submitter = entry["submitter"]
6 changes: 4 additions & 2 deletions ingest/scripts/compare_hashes.py
@@ -6,6 +6,7 @@
from typing import Any

import click
import orjsonl
import requests
import yaml

@@ -171,7 +172,6 @@ def main(
config.debug_hashes = True

submitted: dict = json.load(open(old_hashes, encoding="utf-8"))
new_metadata = json.load(open(metadata, encoding="utf-8"))

update_manager = SequenceUpdateManager(
submit=[],
@@ -184,7 +184,9 @@ def main(
config=config,
)

for fasta_id, record in new_metadata.items():
for field in orjsonl.stream(metadata):
fasta_id = field["id"]
record = field["metadata"]
if not config.segmented:
insdc_accession_base = record["insdcAccessionBase"]
if not insdc_accession_base:
33 changes: 17 additions & 16 deletions ingest/scripts/group_segments.py
@@ -1,6 +1,7 @@
"""Script to group segments together into sequence entries prior to submission to Loculus
Example output for a single isolate with 3 segments:
"KJ682796.1.L/KJ682809.1.M/KJ682819.1.S": {
Example ndjson output for a single isolate with 3 segments:
{"id": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S",
"metadata": {
"ncbiReleaseDate": "2014-07-06T00:00:00Z",
"ncbiSourceDb": "GenBank",
"authors": "D. Goedhals, F.J. Burt, J.T. Bester, R. Swanepoel",
@@ -15,7 +16,7 @@
"hash_S": "f716ed13dca9c8a033d46da2f3dc2ff1",
"hash": "ce7056d0bd7e3d6d3eca38f56b9d10f8",
"submissionId": "KJ682796.1.L/KJ682809.1.M/KJ682819.1.S"
},"""
}}"""

import hashlib
import json
@@ -100,9 +101,11 @@ def main(
segments = config.nucleotide_sequences
number_of_segments = len(segments)

with open(input_metadata, encoding="utf-8") as file:
segment_metadata: dict[str, dict[str, str]] = json.load(file)
number_of_segmented_records = len(segment_metadata.keys())
number_of_segmented_records = 0
segment_metadata: dict[str, dict[str, str]] = {}
for record in orjsonl.stream(input_metadata):
segment_metadata[record["id"]] = record["metadata"]
number_of_segmented_records += 1
logger.info(f"Found {number_of_segmented_records} individual segments in metadata file")

# Group segments according to isolate, collection date and isolate specific values
@@ -174,7 +177,7 @@ def main(
number_of_groups = len(grouped_accessions)
group_lower_bound = number_of_segmented_records // number_of_segments
group_upper_bound = number_of_segmented_records
logging.info(f"Total of {number_of_groups} groups left after merging")
logger.info(f"Total of {number_of_groups} groups left after merging")
if number_of_groups < group_lower_bound:
raise ValueError(
{
@@ -192,11 +195,11 @@ def main(
}
)

# Add segment specific metadata for the segments
metadata: dict[str, dict[str, str]] = {}
# Map from original accession to the new concatenated accession
fasta_id_map: dict[Accession, Accession] = {}

count = 0

for group in grouped_accessions:
# Create key by concatenating all accession numbers with their segments
# e.g. AF1234_S/AF1235_M/AF1236_L
@@ -241,12 +244,10 @@ def main(
json.dumps(filtered_record, sort_keys=True).encode(), usedforsecurity=False
).hexdigest()

metadata[joint_key] = row
orjsonl.append(output_metadata, {"id": joint_key, "metadata": row})
count += 1

Path(output_metadata).write_text(
json.dumps(metadata, indent=4, sort_keys=True), encoding="utf-8"
)
logging.info(f"Wrote grouped metadata for {len(metadata)} sequences")
logger.info(f"Wrote grouped metadata for {count} sequences")

count = 0
count_ignored = 0
@@ -265,8 +266,8 @@ def main(
},
)
count += 1
logging.info(f"Wrote {count} sequences")
logging.info(f"Ignored {count_ignored} sequences as not found in {input_seq}")
logger.info(f"Wrote {count} sequences")
logger.info(f"Ignored {count_ignored} sequences as not found in {input_seq}")


if __name__ == "__main__":
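One detail worth spelling out from the grouping hunks above: the sanity check on the number of groups is pure counting. A hypothetical worked example (the numbers are illustrative, not from the PR):

```python
# Hypothetical example of the bounds check in group_segments.py
number_of_segmented_records = 300  # individual segment records streamed from the NDJSON input
number_of_segments = 3             # e.g. the L, M and S segments

group_lower_bound = number_of_segmented_records // number_of_segments  # 100: every isolate has all segments
group_upper_bound = number_of_segmented_records                        # 300: every segment ends up alone

number_of_groups = 180  # whatever the grouping step produced
assert group_lower_bound <= number_of_groups <= group_upper_bound, "implausible number of groups after merging"
```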
85 changes: 51 additions & 34 deletions ingest/scripts/prepare_files.py
@@ -1,6 +1,7 @@
import csv
import json
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
@@ -101,65 +102,81 @@ def main(
relevant_config = {key: full_config[key] for key in Config.__annotations__}
config = Config(**relevant_config)

metadata = json.load(open(metadata_path, encoding="utf-8"))
to_submit = json.load(open(to_submit_path, encoding="utf-8"))
to_revise = json.load(open(to_revise_path, encoding="utf-8"))
to_revoke = json.load(open(to_revoke_path, encoding="utf-8"))

metadata_submit = []
metadata_revise = []
metadata_submit_prior_to_revoke = [] # Only for multi-segmented case, sequences are revoked
# due to grouping changes and the newly grouped segments must be submitted as new sequences
submit_ids = set()
revise_ids = set()
submit_prior_to_revoke_ids = set()

for fasta_id in to_submit:
metadata_submit.append(metadata[fasta_id])
submit_ids.update(ids_to_add(fasta_id, config))
def write_to_tsv_stream(data, filename, columns_list=None):
# Check if the file exists
file_exists = os.path.exists(filename)

for fasta_id, loculus_accession in to_revise.items():
revise_record = metadata[fasta_id]
revise_record["accession"] = loculus_accession
metadata_revise.append(revise_record)
revise_ids.update(ids_to_add(fasta_id, config))
with open(filename, "a", newline="", encoding="utf-8") as output_file:
keys = columns_list or data.keys()
dict_writer = csv.DictWriter(output_file, keys, delimiter="\t")

found_seq_to_revoke = False
for fasta_id in to_revoke:
metadata_submit_prior_to_revoke.append(metadata[fasta_id])
submit_prior_to_revoke_ids.update(ids_to_add(fasta_id, config))
# Write the header only if the file doesn't already exist
if not file_exists:
dict_writer.writeheader()

if found_seq_to_revoke:
revocation_notification(config, to_revoke)
dict_writer.writerow(data)

def write_to_tsv(data, filename):
if not data:
Path(filename).touch()
return
keys = data[0].keys()
with open(filename, "w", newline="", encoding="utf-8") as output_file:
dict_writer = csv.DictWriter(output_file, keys, delimiter="\t")
dict_writer.writeheader()
dict_writer.writerows(data)
columns_list = None
for field in orjsonl.stream(metadata_path):
fasta_id = field["id"]
record = field["metadata"]
if not columns_list:
columns_list = record.keys()

if fasta_id in to_submit:
write_to_tsv_stream(record, metadata_submit_path, columns_list)
submit_ids.update(ids_to_add(fasta_id, config))
continue

if fasta_id in to_revise:
record["accession"] = to_revise[fasta_id]
write_to_tsv_stream(record, metadata_revise_path, [*columns_list, "accession"])
revise_ids.update(ids_to_add(fasta_id, config))
continue

found_seq_to_revoke = False
if fasta_id in to_revoke:
submit_prior_to_revoke_ids.update(ids_to_add(fasta_id, config))
write_to_tsv_stream(record, metadata_submit_prior_to_revoke_path, columns_list)
found_seq_to_revoke = True

write_to_tsv(metadata_submit, metadata_submit_path)
write_to_tsv(metadata_revise, metadata_revise_path)
write_to_tsv(metadata_submit_prior_to_revoke, metadata_submit_prior_to_revoke_path)
if found_seq_to_revoke:
revocation_notification(config, to_revoke)

def stream_filter_to_fasta(input, output, keep):
def stream_filter_to_fasta(input, output, output_metadata, keep):
if len(keep) == 0:
Path(output).touch()
Path(output_metadata).touch()
return
with open(output, "w", encoding="utf-8") as output_file:
for record in orjsonl.stream(input):
if record["id"] in keep:
output_file.write(f">{record['id']}\n{record['sequence']}\n")

stream_filter_to_fasta(input=sequences_path, output=sequences_submit_path, keep=submit_ids)
stream_filter_to_fasta(input=sequences_path, output=sequences_revise_path, keep=revise_ids)
stream_filter_to_fasta(
input=sequences_path,
output=sequences_submit_path,
output_metadata=metadata_submit_path,
keep=submit_ids,
)
stream_filter_to_fasta(
input=sequences_path,
output=sequences_revise_path,
output_metadata=metadata_revise_path,
keep=revise_ids,
)
stream_filter_to_fasta(
input=sequences_path,
output=sequences_submit_prior_to_revoke_path,
output_metadata=metadata_submit_prior_to_revoke_path,
keep=submit_prior_to_revoke_ids,
)

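prepare_files.py now writes each metadata row as it is streamed rather than collecting lists and dumping them at the end: the TSV is opened in append mode and the header is written only when the file does not yet exist. A simplified sketch of that idea (function and file names here are illustrative, not the PR's):

```python
import csv
import os

def append_row_tsv(row: dict, filename: str, columns: list[str] | None = None) -> None:
    """Append one row to a TSV file, writing the header only when the file is first created."""
    file_exists = os.path.exists(filename)
    with open(filename, "a", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, columns or list(row.keys()), delimiter="\t")
        if not file_exists:
            writer.writeheader()
        writer.writerow(row)

# Rows arrive one at a time from the NDJSON stream; the header is emitted once.
append_row_tsv({"id": "A1", "country": "KE"}, "example_submit.tsv", ["id", "country"])
append_row_tsv({"id": "A2", "country": "UG"}, "example_submit.tsv", ["id", "country"])
```

Passing an explicit column list keeps every row aligned to the same header even when individual records gain extra keys, as the revise path does when it adds an accession column.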
7 changes: 2 additions & 5 deletions ingest/scripts/prepare_metadata.py
@@ -10,7 +10,6 @@
import json
import logging
from dataclasses import dataclass
from pathlib import Path

import click
import orjsonl
@@ -143,11 +142,9 @@ def main(

record["hash"] = hashlib.md5(prehash.encode(), usedforsecurity=False).hexdigest()

meta_dict = {rec[fasta_id_field]: rec for rec in metadata}
orjsonl.append(output, {"id": record[fasta_id_field], "metadata": record})

Path(output).write_text(json.dumps(meta_dict, indent=4, sort_keys=True), encoding="utf-8")

logging.info(f"Saved metadata for {len(metadata)} sequences")
logger.info(f"Saved metadata for {len(metadata)} sequences")


if __name__ == "__main__":