Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 53 additions & 62 deletions RecoTracker/LSTCore/standalone/analysis/DNN/embed_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -353,45 +353,38 @@
" eta1 = F[:, 0] * eta_max\n",
" phi1 = np.arctan2(F[:, 2], F[:, 1])\n",
"\n",
" sim_pairs, dis_pairs = [], []\n",
"\n",
" # similar pairs (same sim-index)\n",
" buckets = {}\n",
" for idx, s in enumerate(S):\n",
" if s != invalid_sim:\n",
" buckets.setdefault(s, []).append(idx)\n",
"\n",
" for lst in buckets.values():\n",
" if len(lst) < 2:\n",
" continue\n",
" for a in range(len(lst) - 1):\n",
" i = lst[a]\n",
" for b in range(a + 1, len(lst)):\n",
" j = lst[b]\n",
" dphi = _delta_phi(phi1[i], phi1[j])\n",
" dr2 = (eta1[i] - eta1[j])**2 + dphi**2\n",
" if dr2 < DELTA_R2_CUT:\n",
" sim_pairs.append((i, j))\n",
"\n",
" # dissimilar pairs (different sim)\n",
" for i in range(n - 1):\n",
" si, ei, pi = S[i], eta1[i], phi1[i]\n",
" for j in range(i + 1, n):\n",
" # skip fake-fake pairs\n",
" if si == invalid_sim and S[j] == invalid_sim:\n",
" continue\n",
" if (si == S[j]) and si != invalid_sim:\n",
" continue\n",
" dphi = _delta_phi(pi, phi1[j])\n",
" dr2 = (ei - eta1[j])**2 + dphi**2\n",
" if dr2 < DELTA_R2_CUT:\n",
" dis_pairs.append((i, j))\n",
" # upper-triangle (non-diagonal) indices\n",
" idx_l, idx_r = np.triu_indices(n, k=1)\n",
" idxs_triu = np.stack((idx_l, idx_r), axis=-1)\n",
"\n",
" # sim indices for each pair\n",
" simidx_l = S[idx_l]\n",
" simidx_r = S[idx_r]\n",
"\n",
" # calculate DR2\n",
" eta_l = eta1[idx_l]\n",
" eta_r = eta1[idx_r]\n",
" phi_l = phi1[idx_l]\n",
" phi_r = phi1[idx_r]\n",
" dphi = np.abs(phi_l - phi_r)\n",
" dphi[dphi > np.pi] -= 2 * np.pi # adjust to [-pi, pi]\n",
" dr2 = (eta_l - eta_r)**2 + dphi**2\n",
"\n",
" # make masks\n",
" dr2_valid = (dr2 < DELTA_R2_CUT)\n",
" sim_idx_same = (simidx_l == simidx_r)\n",
" sim_mask = dr2_valid & sim_idx_same & (simidx_l != invalid_sim)\n",
" dis_mask = dr2_valid & ~sim_idx_same\n",
"\n",
" # get pairs from masks\n",
" sim_pairs = idxs_triu[sim_mask]\n",
" dis_pairs = idxs_triu[dis_mask]\n",
"\n",
" # down-sample\n",
" if len(sim_pairs) > max_sim:\n",
" sim_pairs = random.sample(sim_pairs, max_sim)\n",
" sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n",
" if len(dis_pairs) > max_dis:\n",
" dis_pairs = random.sample(dis_pairs, max_dis)\n",
" dis_pairs = dis_pairs[random.sample(range(len(dis_pairs)), max_dis)]\n",
"\n",
" print(f\"[evt {evt_idx:4d}] T5s={n:5d} sim={len(sim_pairs):3d} dis={len(dis_pairs):3d}\")\n",
" return evt_idx, sim_pairs, dis_pairs\n",
Expand Down Expand Up @@ -530,35 +523,33 @@
" eta_t = F_T5[:,0] * eta_max\n",
" phi_t = np.arctan2(F_T5[:,2], F_T5[:,1])\n",
"\n",
" # bucket T5 by sim-idx for similar\n",
" buckets = {}\n",
" for j,s in enumerate(S_T5):\n",
" if s != invalid_sim:\n",
" buckets.setdefault(s, []).append(j)\n",
" for i,s in enumerate(S_pLS):\n",
" if s == invalid_sim:\n",
" continue\n",
" for j in buckets.get(s, []):\n",
" dphi = (phi_p[i] - phi_t[j] + np.pi) % (2*np.pi) - np.pi\n",
" dr2 = (eta_p[i] - eta_t[j])**2 + dphi**2\n",
" if dr2 < DELTA_R2_CUT_PLS_T5:\n",
" sim_pairs.append((i,j))\n",
"\n",
" # find dissimilar (different sim-idx) pairs\n",
" for i in range(n_p):\n",
" for j in range(n_t):\n",
" if S_pLS[i] == S_T5[j] and S_pLS[i] != invalid_sim:\n",
" continue\n",
" dphi = (phi_p[i] - phi_t[j] + np.pi) % (2*np.pi) - np.pi\n",
" dr2 = (eta_p[i] - eta_t[j])**2 + dphi**2\n",
" if dr2 < DELTA_R2_CUT_PLS_T5:\n",
" dis_pairs.append((i,j))\n",
"\n",
" # down-sample to limits\n",
" # make all possible pairs (i, j)\n",
" idx_p, idx_t = np.indices( (n_p, n_t) )\n",
" idx_p, idx_t = idx_p.flatten(), idx_t.flatten()\n",
"\n",
" # calculate angles\n",
" dphi = (phi_p[idx_p] - phi_t[idx_t] + np.pi) % (2 * np.pi) - np.pi\n",
" dr2 = (eta_p[idx_p] - eta_t[idx_t])**2 + dphi**2\n",
" dr2_valid = (dr2 < DELTA_R2_CUT_PLS_T5)\n",
"\n",
" # compare sim indices\n",
" simidx_p = S_pLS[idx_p]\n",
" simidx_t = S_T5[idx_t]\n",
" sim_idx_same = (simidx_p == simidx_t)\n",
"\n",
" # create masks for similar and dissimilar pairs\n",
" sim_mask = dr2_valid & sim_idx_same & (simidx_p != invalid_sim)\n",
" dis_mask = dr2_valid & ~sim_idx_same\n",
"\n",
" # get the pairs\n",
" sim_pairs = np.column_stack((idx_p[sim_mask], idx_t[sim_mask]))\n",
" dis_pairs = np.column_stack((idx_p[dis_mask], idx_t[dis_mask]))\n",
"\n",
" # down-sample\n",
" if len(sim_pairs) > max_sim:\n",
" sim_pairs = random.sample(sim_pairs, max_sim)\n",
" sim_pairs = sim_pairs[random.sample(range(len(sim_pairs)), max_sim)]\n",
" if len(dis_pairs) > max_dis:\n",
" dis_pairs = random.sample(dis_pairs, max_dis)\n",
" dis_pairs = dis_pairs[random.sample(range(len(dis_pairs)), max_dis)]\n",
"\n",
" # print per-event summary\n",
" print(f\"[evt {evt_idx:4d}] pLSs={n_p:5d} T5s={n_t:5d} \"\n",
Expand Down