diff --git a/RecoTracker/LSTCore/standalone/analysis/DNN/embed_train.ipynb b/RecoTracker/LSTCore/standalone/analysis/DNN/embed_train.ipynb index 9c50228a53531..6372812513e66 100644 --- a/RecoTracker/LSTCore/standalone/analysis/DNN/embed_train.ipynb +++ b/RecoTracker/LSTCore/standalone/analysis/DNN/embed_train.ipynb @@ -353,45 +353,38 @@ " eta1 = F[:, 0] * eta_max\n", " phi1 = np.arctan2(F[:, 2], F[:, 1])\n", "\n", - " sim_pairs, dis_pairs = [], []\n", - "\n", - " # similar pairs (same sim-index)\n", - " buckets = {}\n", - " for idx, s in enumerate(S):\n", - " if s != invalid_sim:\n", - " buckets.setdefault(s, []).append(idx)\n", - "\n", - " for lst in buckets.values():\n", - " if len(lst) < 2:\n", - " continue\n", - " for a in range(len(lst) - 1):\n", - " i = lst[a]\n", - " for b in range(a + 1, len(lst)):\n", - " j = lst[b]\n", - " dphi = _delta_phi(phi1[i], phi1[j])\n", - " dr2 = (eta1[i] - eta1[j])**2 + dphi**2\n", - " if dr2 < DELTA_R2_CUT:\n", - " sim_pairs.append((i, j))\n", - "\n", - " # dissimilar pairs (different sim)\n", - " for i in range(n - 1):\n", - " si, ei, pi = S[i], eta1[i], phi1[i]\n", - " for j in range(i + 1, n):\n", - " # skip fake-fake pairs\n", - " if si == invalid_sim and S[j] == invalid_sim:\n", - " continue\n", - " if (si == S[j]) and si != invalid_sim:\n", - " continue\n", - " dphi = _delta_phi(pi, phi1[j])\n", - " dr2 = (ei - eta1[j])**2 + dphi**2\n", - " if dr2 < DELTA_R2_CUT:\n", - " dis_pairs.append((i, j))\n", + " # upper-triangle (non-diagonal) indices\n", + " idx_l, idx_r = np.triu_indices(n, k=1)\n", + " idxs_triu = np.stack((idx_l, idx_r), axis=-1)\n", + "\n", + " # sim indices for each pair\n", + " simidx_l = S[idx_l]\n", + " simidx_r = S[idx_r]\n", + "\n", + " # calculate DR2\n", + " eta_l = eta1[idx_l]\n", + " eta_r = eta1[idx_r]\n", + " phi_l = phi1[idx_l]\n", + " phi_r = phi1[idx_r]\n", + " dphi = np.abs(phi_l - phi_r)\n", + " dphi[dphi > np.pi] -= 2 * np.pi # adjust to [-pi, pi]\n", + " dr2 = (eta_l - eta_r)**2 + dphi**2\n", + "\n", + " # make masks\n", + " dr2_valid = (dr2 < DELTA_R2_CUT)\n", + " sim_idx_same = (simidx_l == simidx_r)\n", + " sim_mask = dr2_valid & sim_idx_same & (simidx_l != invalid_sim)\n", + " dis_mask = dr2_valid & ~sim_idx_same\n", + "\n", + " # get pairs from masks\n", + " sim_pairs = idxs_triu[sim_mask]\n", + " dis_pairs = idxs_triu[dis_mask]\n", "\n", " # down-sample\n", " if len(sim_pairs) > max_sim:\n", - " sim_pairs = random.sample(sim_pairs, max_sim)\n", + " sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n", " if len(dis_pairs) > max_dis:\n", - " dis_pairs = random.sample(dis_pairs, max_dis)\n", + " dis_pairs = dis_pairs[random.sample(range(len(dis_pairs)), max_dis)]\n", "\n", " print(f\"[evt {evt_idx:4d}] T5s={n:5d} sim={len(sim_pairs):3d} dis={len(dis_pairs):3d}\")\n", " return evt_idx, sim_pairs, dis_pairs\n", @@ -530,35 +523,33 @@ " eta_t = F_T5[:,0] * eta_max\n", " phi_t = np.arctan2(F_T5[:,2], F_T5[:,1])\n", "\n", - " # bucket T5 by sim-idx for similar\n", - " buckets = {}\n", - " for j,s in enumerate(S_T5):\n", - " if s != invalid_sim:\n", - " buckets.setdefault(s, []).append(j)\n", - " for i,s in enumerate(S_pLS):\n", - " if s == invalid_sim:\n", - " continue\n", - " for j in buckets.get(s, []):\n", - " dphi = (phi_p[i] - phi_t[j] + np.pi) % (2*np.pi) - np.pi\n", - " dr2 = (eta_p[i] - eta_t[j])**2 + dphi**2\n", - " if dr2 < DELTA_R2_CUT_PLS_T5:\n", - " sim_pairs.append((i,j))\n", - "\n", - " # find dissimilar (different sim-idx) pairs\n", - " for i in range(n_p):\n", - " for j in range(n_t):\n", - " if S_pLS[i] == S_T5[j] and S_pLS[i] != invalid_sim:\n", - " continue\n", - " dphi = (phi_p[i] - phi_t[j] + np.pi) % (2*np.pi) - np.pi\n", - " dr2 = (eta_p[i] - eta_t[j])**2 + dphi**2\n", - " if dr2 < DELTA_R2_CUT_PLS_T5:\n", - " dis_pairs.append((i,j))\n", - "\n", - " # down-sample to limits\n", + " # make all possible pairs (i, j)\n", + " idx_p, idx_t = np.indices( (n_p, n_t) )\n", + " idx_p, idx_t = idx_p.flatten(), idx_t.flatten()\n", + "\n", + " # calculate angles\n", + " dphi = (phi_p[idx_p] - phi_t[idx_t] + np.pi) % (2 * np.pi) - np.pi\n", + " dr2 = (eta_p[idx_p] - eta_t[idx_t])**2 + dphi**2\n", + " dr2_valid = (dr2 < DELTA_R2_CUT_PLS_T5)\n", + "\n", + " # compare sim indices\n", + " simidx_p = S_pLS[idx_p]\n", + " simidx_t = S_T5[idx_t]\n", + " sim_idx_same = (simidx_p == simidx_t)\n", + "\n", + " # create masks for similar and dissimilar pairs\n", + " sim_mask = dr2_valid & sim_idx_same & (simidx_p != invalid_sim)\n", + " dis_mask = dr2_valid & ~sim_idx_same\n", + "\n", + " # get the pairs\n", + " sim_pairs = np.column_stack((idx_p[sim_mask], idx_t[sim_mask]))\n", + " dis_pairs = np.column_stack((idx_p[dis_mask], idx_t[dis_mask]))\n", + "\n", + " # down-sample\n", " if len(sim_pairs) > max_sim:\n", - " sim_pairs = random.sample(sim_pairs, max_sim)\n", + " sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n", " if len(dis_pairs) > max_dis:\n", - " dis_pairs = random.sample(dis_pairs, max_dis)\n", + " dis_pairs = dis_pairs[random.sample(range(len(dis_pairs)), max_dis)]\n", "\n", " # print per-event summary\n", " print(f\"[evt {evt_idx:4d}] pLSs={n_p:5d} T5s={n_t:5d} \"\n",