Commit

Update and rename nn-related/single_plane.py to single_plane.py
haardie authored Feb 4, 2025
1 parent f455353 commit f663055
Showing 1 changed file with 28 additions and 13 deletions.
41 changes: 28 additions & 13 deletions nn-related/single_plane.py → single_plane.py
@@ -15,6 +15,7 @@
import numpy as np
import json
from concurrent.futures import ThreadPoolExecutor
+ from pathlib import Path

#
sys.path.append("./src")
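
The pathlib import added in this hunk backs the Path-based event scan that appears further down in the commit. For comparison, the two directory-listing idioms used in this file side by side (the /some/dir path is just a placeholder):

import os
from pathlib import Path

# os.listdir() returns bare name strings, so each entry must be re-joined
# and re-checked by hand:
subdirs_os = [
    d for d in os.listdir("/some/dir")
    if os.path.isdir(os.path.join("/some/dir", d))
]

# Path.iterdir() yields Path objects, so joining and is_dir() come free:
subdirs_pl = [p for p in Path("/some/dir").iterdir() if p.is_dir()]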
@@ -56,6 +57,7 @@
num_epochs = config["training"]["num_epochs"]
patience = config["training"]["patience"]

+ sgn = "pos"
# ==================================#
# ======== SET UP DIRECTORIES ======#
# ==================================#
@@ -99,9 +101,19 @@
val_df = pd.DataFrame(columns=["ground_truth", "output"])
# ==================================#
# model = cls.ModifiedResNet()
# ==================================#
- model = cls.ModifiedMobileNetV3()

+ ## SMALLER MODELS: =================
+ # -> mobilenets
+ # model = cls.ModifiedMobileNetV3()
+ # model = cls.ModifiedMobileNetV2()
+ # -> efficientnets
+ # ==================================

+ # set torch's CPU thread count to the number of available cores
+ torch.set_num_threads(4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = cls.ModifiedEfficientNet(dropout=0.1, device=device)

# Distribute model across all available GPUs
if torch.cuda.device_count() > 1:
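
cls.ModifiedEfficientNet lives in the project's src/ tree and is not shown in this diff. A minimal sketch of what such a wrapper might look like, assuming a torchvision EfficientNet-B0 backbone, a single-logit head for signal/background separation, and single-channel plane images — all guesses about the real class:

import torch
import torch.nn as nn
from torchvision import models

class ModifiedEfficientNet(nn.Module):
    def __init__(self, dropout=0.1, device=None):
        super().__init__()
        self.backbone = models.efficientnet_b0(weights=None)
        # assumption: detector planes are single-channel images, so swap the
        # stock 3-channel stem conv for a 1-channel one
        self.backbone.features[0][0] = nn.Conv2d(
            1, 32, kernel_size=3, stride=2, padding=1, bias=False
        )
        # replace the ImageNet head with dropout + a single logit
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, 1),
        )
        if device is not None:
            self.to(device)

    def forward(self, x):
        return self.backbone(x)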
@@ -164,26 +176,27 @@
) # event dirs: metadata in CSV format & per-plane dirs -> npz files
background_dir = os.path.join(data_dir, "atmonu-npz")

# event_dirs = signal_dirs.result() + background_dirs.result()
signal_decay_dirs = [
os.path.join("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/pdk_decays", dir)
- for dir in os.listdir("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/pdk_decays")[:2]
+ for dir in os.listdir("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/pdk_decays")
if os.path.isdir(
os.path.join("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/pdk_decays", dir)
)
]
background_decay_dirs = [
os.path.join("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/atmonu_decays", dir)
- for dir in os.listdir("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/atmonu_decays")[
- :2
- ]
+ for dir in os.listdir("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/atmonu_decays")
if os.path.isdir(
os.path.join("/mnt/lustre/helios-shared/GAMS/dune/pdk-root/atmonu_decays", dir)
)
]


decay_dirs = signal_decay_dirs + background_decay_dirs
- dataset = cls.SparseMatrixDatasetMeta(decay_dirs=decay_dirs, plane_idx=plane)
+ events = (evt for decay_dir in decay_dirs for evt in Path(decay_dir).iterdir() if evt.is_dir())
+ dataset = cls.SparseMatrixDataset(event_paths=events, plane_idx=plane)
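
SparseMatrixDataset is also project code. A rough sketch of the interface the call above implies — one subdirectory per event, each holding per-plane sparse matrices — with the file naming, scipy-based loading, and label rule all hypothetical:

import scipy.sparse as sp
import torch
from torch.utils.data import Dataset

class SparseMatrixDataset(Dataset):
    def __init__(self, event_paths, plane_idx):
        # materialize up front: a one-shot generator would break len()/indexing
        self.event_paths = list(event_paths)
        self.plane_idx = plane_idx

    def __len__(self):
        return len(self.event_paths)

    def __getitem__(self, i):
        evt = self.event_paths[i]
        # hypothetical layout: <event_dir>/plane_<idx>.npz
        mat = sp.load_npz(evt / f"plane_{self.plane_idx}.npz").toarray()
        x = torch.from_numpy(mat).float().unsqueeze(0)  # add channel dim
        # hypothetical label rule: signal events live under the pdk_decays tree
        y = torch.tensor([1.0 if "pdk_decays" in str(evt) else 0.0])
        return x, y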

# print(dataset[:5])

print("==================")
print("Splitting dataset.")
@@ -199,17 +212,18 @@

dataloaders = {
"train": DataLoader(
- train, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
+ train, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True
),
"test": DataLoader(
- test, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
+ test, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True
),
"val": DataLoader(
- val, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
+ val, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=True
),
}
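
drop_last=True, newly added to all three loaders, discards the final incomplete batch so every step sees exactly batch_size samples — which keeps per-batch statistics (e.g., BatchNorm) and loss averaging consistent. A quick illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))
print(len(DataLoader(ds, batch_size=4)))                  # 3 -> last batch has only 2 items
print(len(DataLoader(ds, batch_size=4, drop_last=True)))  # 2 -> partial batch dropped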
# plot decay mode distribution
# fns.eval_decay_distrib(train, val, test, dataset, output_dir=".")
print(f'Lengths. Dataset: {len(dataset)}, Train: {len(train)}, Val: {len(val)}, Test: {len(test)}')
print(f'Dataloaders: Train: {len(dataloaders["train"])}, Val: {len(dataloaders["val"])}, Test: {len(dataloaders["test"])}')


# ==================================#
# =========== TRAINING =============#
@@ -274,4 +288,5 @@
val_acc,
)


run.finish()
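
run.finish() matches the Weights & Biases API: it closes out the run and flushes any buffered metrics. The matching wandb.init() sits outside this diff; the usual pairing looks like this (project name hypothetical):

import wandb

run = wandb.init(project="single-plane")  # hypothetical project name
# ... during training: run.log({"val_acc": val_acc}) ...
run.finish()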
