4 changes: 4 additions & 0 deletions .config/nextest.toml
@@ -0,0 +1,4 @@
# Match all tests in the project; each matched test reserves 4 thread slots
[[profile.default.overrides]]
filter = "test(/.*/)"
threads-required = 4
2 changes: 2 additions & 0 deletions .gitignore
@@ -14,3 +14,5 @@ Cargo.lock
**/*.rs.bk

.github/copilot-instructions.md
**/CLAUDE.md
live-sync.sh
3 changes: 3 additions & 0 deletions src/cli/mod.rs
@@ -65,6 +65,9 @@ pub enum Commands {

#[command(arg_required_else_help = true)]
Extract(ExtractArgs),

/// Display version information, build details, and project URLs
Version,
}

#[derive(Debug, Args)]
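For context, a minimal sketch of what a handler for the new `Version` subcommand might print, using Cargo's compile-time package metadata; the handler body is outside this diff, so the output fields here are assumptions:

```rust
// Hypothetical `Version` handler; the printed fields are assumptions, not
// taken from this PR. Cargo sets these env vars at compile time.
fn print_version() {
    println!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
    println!("repository: {}", env!("CARGO_PKG_REPOSITORY"));
}

fn main() {
    print_version();
}
```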
19 changes: 6 additions & 13 deletions src/consensus/formatter.rs
@@ -9,10 +9,8 @@ use std::str;

/// A struct to format the header of a consensus sequence
pub struct HeaderFormatter<'a> {
group: &'a ArchivedDuplicateGroup<'a>,
args: &'a ConsensusArgs,
key: String,
group_size: usize,
id: usize,
start_time: Instant,
consensus_type: String,
@@ -25,38 +23,33 @@ impl<'a> HeaderFormatter<'a> {
let id = group.id;

let key = match group.key {
ArchivedDuplicateGroupKey::Normal(id) => id.to_string(),
_ => "todo".to_string(),
ArchivedDuplicateGroupKey::Valid(id) => id.to_string(),
ArchivedDuplicateGroupKey::Filtered(id, _) => id.to_string(),
};

let start_time = Instant::now();
let consensus_type = match group.key {
ArchivedDuplicateGroupKey::Normal(_) => {
ArchivedDuplicateGroupKey::Valid(_) => {
if group_size == 1 {
"single".to_string()
} else {
"consensus".to_string()
}
}
ArchivedDuplicateGroupKey::Invalid(_) => "ignored".to_string(),
ArchivedDuplicateGroupKey::Filtered(_) => "filtered".to_string(),
ArchivedDuplicateGroupKey::Filtered(_, _) => "filtered".to_string(),
};

// placeholder
let orig_headers = vec![Vec::new()];

let s = Self {
group,
Self {
args,
key,
group_size,
id,
start_time,
consensus_type,
orig_headers,
};

s
}
}

pub fn add_read(&mut self, cluster_id: usize, read: &SequenceRecord) {
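The renames in this file (`Normal` to `Valid`, `Filtered` gaining a second field, `Invalid` disappearing) imply a key enum roughly like the sketch below. The enum definition itself is outside this diff, so the field types are assumptions:

```rust
// Hedged reconstruction of the key type implied by the diff;
// `RecordIdentifier` is stubbed and the `u64` position field is an assumption.
struct RecordIdentifier(String);

enum DuplicateGroupKey {
    Valid(RecordIdentifier),
    Filtered(RecordIdentifier, u64),
}

// Mirrors the consensus_type match above: no `Invalid`/"ignored" arm remains.
fn consensus_type(key: &DuplicateGroupKey, group_size: usize) -> &'static str {
    match key {
        DuplicateGroupKey::Valid(_) if group_size == 1 => "single",
        DuplicateGroupKey::Valid(_) => "consensus",
        DuplicateGroupKey::Filtered(_, _) => "filtered",
    }
}

fn main() {
    let key = DuplicateGroupKey::Valid(RecordIdentifier("BC_UMI".into()));
    assert_eq!(consensus_type(&key, 1), "single");
}
```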
41 changes: 26 additions & 15 deletions src/consensus/mod.rs
@@ -12,10 +12,29 @@ use std::io::{Cursor, Write as IoWrite};

use rayon::prelude::*;
use spoa::{AlignmentEngine, AlignmentType};
use std::cell::RefCell;

mod cluster;
mod formatter;

// Perf optimization: keep one AlignmentEngine per thread so consensus_call does not construct a new engine on every invocation
thread_local! {
static ALIGNMENT_ENGINE: RefCell<AlignmentEngine> = RefCell::new(
AlignmentEngine::new(AlignmentType::kOV, 5, -4, -8, -6, -10, -4)
);
}

/// Helper function to get access to the thread-local AlignmentEngine
fn with_alignment_engine<F, R>(f: F) -> R
where
F: FnOnce(&mut AlignmentEngine) -> R,
{
ALIGNMENT_ENGINE.with(|engine| {
let mut engine = engine.borrow_mut();
f(&mut *engine)
})
}

/// Generate consensus sequences from duplicate read groups
pub fn consensus(args: &crate::cli::ConsensusArgs) -> Result<()> {
let paths = FileIndexPath::new(&args.input);
@@ -164,16 +183,13 @@ fn consensus_call(
"@{}\n{}\n+\n{}",
header,
str::from_utf8(&seq)?,
str::from_utf8(&qual)?
str::from_utf8(qual)?
)?;

Ok((result, 1, false))
} else {
// consensus call

// initialise `spoa` machinery
let mut alignment_engine = AlignmentEngine::new(AlignmentType::kOV, 5, -4, -8, -6, -10, -4);

let mut read_idx = 0usize;

let mut graphs = vec![spoa::Graph::new()];
@@ -192,7 +208,7 @@
let mut inserted_cluster_id = 0;
for (cluster_id, graph) in graphs.iter_mut().enumerate() {
// Align to the graph
let align = alignment_engine.align_from_bytes(&seq, graph);
let align = with_alignment_engine(|engine| engine.align_from_bytes(&seq, graph));

let will_cluster = if first_read_in_group || args.no_clustering {
true
@@ -207,13 +223,7 @@
};

if will_cluster {
let alignment_result = graph.add_alignment_from_bytes(&align, &seq, &qual);

// debug check for bugs in the alignment prediction algorithm
// todo: fix this?
// if let Some(alignment_prediction) = alignment_predictions.last() {
// assert_eq!(alignment_prediction, alignment_result)
// }
let alignment_result = graph.add_alignment_from_bytes(&align, &seq, qual);

// add each read in the duplicate group to the graph
inserted_cluster_id = cluster_id;
@@ -229,8 +239,9 @@
// do we need to add a new graph, because this read didn't cluster?
if !did_cluster {
let mut new_graph = spoa::Graph::new();
let align = alignment_engine.align_from_bytes(&seq, &new_graph);
alignment_predictions.push(new_graph.add_alignment_from_bytes(&align, &seq, &qual));
let align =
with_alignment_engine(|engine| engine.align_from_bytes(&seq, &new_graph));
alignment_predictions.push(new_graph.add_alignment_from_bytes(&align, &seq, qual));

debug!("Added new graph");
graphs.push(new_graph);
@@ -253,7 +264,7 @@
"@{}\n{}\n+\n{}",
header,
str::from_utf8(&seq).unwrap(),
str::from_utf8(&qual).unwrap()
str::from_utf8(qual).unwrap()
)?;
}

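The thread-local engine works because rayon reuses a fixed pool of worker threads: each worker constructs the engine once on first use, then reuses it across all work items it steals. A self-contained sketch of the same pattern, with a reusable `String` standing in for `spoa::AlignmentEngine` (assumed expensive to construct, which is the PR's stated motivation; requires the `rayon` crate):

```rust
use rayon::prelude::*;
use std::cell::RefCell;
use std::fmt::Write;

thread_local! {
    // Built once per rayon worker thread, then reused across work items.
    static SCRATCH: RefCell<String> = RefCell::new(String::with_capacity(64));
}

fn with_scratch<F, R>(f: F) -> R
where
    F: FnOnce(&mut String) -> R,
{
    SCRATCH.with(|s| f(&mut s.borrow_mut()))
}

fn main() {
    let total: usize = (0..10_000u32)
        .into_par_iter()
        .map(|n| {
            with_scratch(|buf| {
                buf.clear(); // reuse the allocation instead of reallocating
                write!(buf, "record-{n}").unwrap();
                buf.len()
            })
        })
        .sum();
    assert!(total > 0);
}
```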
2 changes: 1 addition & 1 deletion src/extract.rs
@@ -40,7 +40,7 @@ pub fn extract(args: &crate::cli::ExtractArgs) -> anyhow::Result<()> {
index
.groups()
.filter_map(|group| {
if let ArchivedDuplicateGroupKey::Normal(k) = group.key {
if let ArchivedDuplicateGroupKey::Valid(k) = group.key {
if re.is_match(&k.0) {
return Some(group.id);
}
152 changes: 70 additions & 82 deletions src/io/index/construct.rs
@@ -9,8 +9,9 @@ use humansize::{format_size, FormatSizeOptions};
use needletail::parser::SequenceRecord;
use needletail::FastxReader;
use regex::Regex;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Seek};
use std::io::{BufRead, BufReader, Seek};
use std::path::Path;
use thiserror::Error;

@@ -53,7 +54,7 @@ pub fn construct_index(cli: &crate::cli::IndexArgs) -> Result<()> {
let mut last_report = 0;

// callback function to process each read that comes in & add to the index
let callback = |loc: ReadLocation, key: DuplicateGroupKey, seq: SequenceRecord| {
let callback = |loc: ReadLocation, id: RecordIdentifier, seq: SequenceRecord| {
let pos = loc.pos();

// should we report our current progress?
@@ -67,18 +68,18 @@
);
}

let key = if !key.is_invalid() && !should_keep(&seq, &filters) {
DuplicateGroupKey::Filtered(pos)
let key = if !should_keep(&seq, &filters) {
DuplicateGroupKey::Filtered(id, pos)
} else {
key
DuplicateGroupKey::Valid(id)
};

index.add_read(key, loc, seq);
Ok(())
};

if let Some(loc) = &cli.clusters {
todo!();
if let Some(cluster_file) = &cli.clusters {
iter_lines_with_cluster_file(&mut reader, cluster_file, callback)?;
} else {
let re = match &cli.barcode_regex {
Some(v) => {
@@ -112,7 +113,7 @@
mut callback: F,
) -> Result<()>
where
F: FnMut(ReadLocation, DuplicateGroupKey, SequenceRecord) -> Result<()>,
F: FnMut(ReadLocation, RecordIdentifier, SequenceRecord) -> Result<()>,
{
// expected_len is used to ensure that every read has the same format
let mut expected_len: Option<usize> = None;
@@ -130,102 +131,89 @@
};
let header = std::str::from_utf8(rec.id())?;

let bc = extract_header_id(header, re, read_location.pos());
let barcode_location = match bc {
Ok((len, id)) => {
// check # of barcode groups is the same
let expected_len = *expected_len.get_or_insert(len);
if expected_len != len {
bail!(IndexGenerationErr::DifferentMatchCounts {
header: header.to_string(),
re: re.clone(),
pos: read_location.pos(),
count: len,
expected: expected_len
})
}

DuplicateGroupKey::Normal(id)
}
Err(_e) => DuplicateGroupKey::Invalid(read_location.pos()),
};
let (len, id) = extract_header_id(header, re, read_location.pos())?;

// check # of barcode groups is the same
let expected_len = *expected_len.get_or_insert(len);
if expected_len != len {
bail!(IndexGenerationErr::DifferentMatchCounts {
header: header.to_string(),
re: re.clone(),
pos: read_location.pos(),
count: len,
expected: expected_len
})
}

callback(read_location, barcode_location, rec)?;
callback(read_location, id, rec)?;
}
Ok(())
}

fn iter_lines_with_cluster_file(
reader: BufReader<File>,
// wtr: &mut IndexWriter,
fn iter_lines_with_cluster_file<F>(
reader: &mut BufReader<File>,
cluster_file: &Path,
skip_invalid_ids: bool,
) -> Result<()> {
todo!();
/*
todo: filepath
let mut cluster_rdr = csv::ReaderBuilder::new()
.delimiter(b';')
.has_headers(false)
.from_path(filepath)?;

// first, we will read the clusters file
info!("Reading identifiers from clusters file...");

let mut cluster_map = std::collections::HashMap::new();

for result in clusters.records() {
let record = result?;

let read_id = record[0].to_string();
let identifier = match record.len() {
// in this case, there is just one identifier (no BC and UMI) so we read the first
// column directly as the 'identifier'
2 => record[1].to_string(),

// in this case, there are two identifiers (i.e. BC and UMI) so we combine them to
// produce an 'identifier'
3 => format!("{}_{}", &record[1], &record[2]),

// doesn't make sense
_ => bail!(InvalidClusterRow {
row: record.as_slice().to_string()
}),
};
mut callback: F,
) -> Result<()>
where
F: FnMut(ReadLocation, RecordIdentifier, SequenceRecord) -> Result<()>,
{
// Read cluster file line by line
info!("Reading identifiers from file {}", cluster_file.display());

let cluster_f = File::open(cluster_file)?;
let cluster_reader = BufReader::new(cluster_f);
let mut cluster_map = HashMap::new();

for line in cluster_reader.lines() {
let line = line?;
let line = line.trim();
if line.is_empty() {
continue;
}

let parts: Vec<&str> = line.split(';').collect();
if parts.len() != 2 {
bail!(IndexGenerationErr::InvalidClusterRow {
row: line.to_string()
});
}

let read_id = parts[0].to_string();
let identifier = RecordIdentifier::from_recs(&[parts[1]]);
cluster_map.insert(read_id, identifier);
}

info!("Finished reading clusters. ");
info!(
"Finished reading clusters. Found {} cluster mappings",
cluster_map.len()
);

let mut fastq_reader = needletail::parser::FastqReader::new(reader);

while let Some(rec) = fastq_reader.next() {
let rec = rec.expect("Invalid record");
let pos = rec.position().byte() as usize;
let bytes_len = rec.all().len() + 1;

match cluster_map.get(&rec.id) {}
let Some(identifier) = cluster_map.get(&rec.id) else {
if !skip_invalid_ids {
bail!(RowNotInClusters { header: rec.id })
}
wtr.metadata.unmatched_read_count += 1;
continue;
let read_location = ReadLocation {
_pos: rec.position().byte(),
_byte_len: (rec.all().len() as u32) + 1,
seq_len: rec.num_bases() as u32,
qual: rec.phred_quality_avg().unwrap_or_default(),
};
wtr.metadata.matched_read_count += 1;

rec.id = identifier.clone();
wtr.add_record(&rec, position, file_len, ignored)?;
let header = std::str::from_utf8(rec.id())?;

total_quality += rec.phred_quality_total();
total_len += rec.len();
}
let Some(id) = cluster_map.get(header) else {
bail!(IndexGenerationErr::RowNotInClusters {
header: header.to_string()
})
};
let id = id.clone();

wtr.write_size((fastq_reader.position().byte() as f64) / (1024u32.pow(3) as f64));
callback(read_location, id, rec)?;
}

Ok(())
*/
}

/// Extract identifier components from a read header using a regex pattern
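Per the loop above, a cluster-file row is two `;`-separated fields: a read ID and an identifier, with blank lines skipped and anything else rejected. A small sketch of that row contract (`parse_row` is a hypothetical helper, not part of the PR):

```rust
// Hypothetical helper illustrating the row format accepted above: exactly
// two ';'-separated fields; any other shape maps to the bail! path.
fn parse_row(line: &str) -> Option<(&str, &str)> {
    let mut parts = line.trim().split(';');
    match (parts.next(), parts.next(), parts.next()) {
        (Some(read_id), Some(identifier), None) => Some((read_id, identifier)),
        _ => None,
    }
}

fn main() {
    assert_eq!(parse_row("read_1;ACGT_TTAG"), Some(("read_1", "ACGT_TTAG")));
    assert_eq!(parse_row("bad;row;extra"), None); // three fields -> rejected
}
```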