diff --git a/bin/scripts/index_all_pairwise.py b/bin/scripts/index_all_pairwise.py index f605f09..e6941e9 100644 --- a/bin/scripts/index_all_pairwise.py +++ b/bin/scripts/index_all_pairwise.py @@ -192,7 +192,7 @@ def find_bin_range(entries, target_bin): @numba.njit -def compute_all_pairs(spectra, shared_entries, shifted_entries, tolerance, threshold): +def compute_all_pairs(spectra, shared_entries, shifted_entries, tolerance, threshold, scoring_func): results = List() n_spectra = len(spectra) @@ -241,7 +241,7 @@ def compute_all_pairs(spectra, shared_entries, shifted_entries, tolerance, thres exact_matches = List() for spec_idx, _ in candidates[:TOPPRODUCTS * 2]: target_spec = spectra[spec_idx] - score, shared, shifted = calculate_exact_score_GNPS(spectra[query_idx], target_spec,tolerance) + score, shared, shifted = scoring_func(spectra[query_idx], target_spec,tolerance) if score >= threshold: exact_matches.append((spec_idx, score, shared, shifted)) @@ -251,9 +251,105 @@ def compute_all_pairs(spectra, shared_entries, shifted_entries, tolerance, thres return results +@numba.njit(fastmath=True) +def calculate_exact_score_GNPS(query_spec, target_spec, TOLERANCE): + """Numba-optimized cosine scoring with shift handling""" + q_mz = query_spec[0] + q_int = query_spec[1] + q_pre = query_spec[2] + q_charge = query_spec[3] + + t_mz = target_spec[0] + t_int = target_spec[1] + t_pre = target_spec[2] + + # Calculate precursor mass difference (assuming charge=1) + precursor_mass_diff = (q_pre - t_pre)*q_charge + allow_shift = True + fragment_tol = TOLERANCE + + # Pre-allocate arrays for matches (adjust size as needed) + max_matches = len(q_mz) * 2 # Estimate maximum possible matches + scores_arr = np.zeros(max_matches, dtype=np.float32) + idx_q = np.zeros(max_matches, dtype=np.int32) + idx_t = np.zeros(max_matches, dtype=np.int32) + match_count = 0 + + # For each peak in query spectrum + for q_idx in range(len(q_mz)): + q_mz_val = q_mz[q_idx] + q_int_val = q_int[q_idx] + + # For each possible shift (charge=1) + num_shifts = 1 + if allow_shift and abs(precursor_mass_diff) >= fragment_tol: + num_shifts += 1 + + for shift_idx in range(num_shifts): + if shift_idx == 0: + # No shift + adjusted_mz = q_mz_val + else: + # Apply shift + adjusted_mz = q_mz_val - precursor_mass_diff + + # Find matching peaks in target using binary search + start = np.searchsorted(t_mz, adjusted_mz - fragment_tol) + end = np.searchsorted(t_mz, adjusted_mz + fragment_tol) + + for t_idx in range(start, end): + if match_count >= max_matches: + break # Prevent overflow + if abs(t_mz[t_idx] - adjusted_mz)