Skip to content

Commit

Permalink
Formatted as a Python library
Browse files Browse the repository at this point in the history
  • Loading branch information
GiorgioMorales committed Mar 21, 2024
1 parent 9d336ea commit 5c7642e
Show file tree
Hide file tree
Showing 23 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
/StatisticsandPlots.py
/src/DualAQD/models/temp_weights/
/src/PredictionIntervals/models/temp_weights/
6 changes: 3 additions & 3 deletions PIGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import itertools
import numpy as np
import matplotlib.pyplot as plt
from src.DualAQD.models.NNModel import NNModel
from src.PredictionIntervals.models.NNModel import NNModel
from sklearn.model_selection import KFold
from src.DualAQD.Datasets.GenerateDatasets import DataLoader
from src.PredictionIntervals.Datasets.GenerateDatasets import DataLoader
from sklearn.model_selection import train_test_split
# Functions needed for QD+
from src.DualAQD.models.aggregation_functions import _split_normal_aggregator # You can comment it if you only want to test DualAQD
from src.PredictionIntervals.models.aggregation_functions import _split_normal_aggregator # You can comment it if you only want to test DualAQD


class PIGenerator:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle
from src.utils import *
from sklearn.model_selection import KFold
from src.DualAQD.models.NNModel import NNModel
from src.PredictionIntervals.models.NNModel import NNModel


class Trainer:
Expand Down Expand Up @@ -46,9 +46,9 @@ def train(self, batch_size=32, epochs=500, eta_=0.01, printProcess=True, normDat
"""
# If the temp folder does not exist, create it
root = get_project_root()
folder = os.path.join(root, "src//DualAQD//models//temp_weights")
if not os.path.exists(os.path.join(root, "src//DualAQD//models//temp_weights")):
os.mkdir(os.path.join(root, "src//DualAQD//models//temp_weights"))
folder = os.path.join(root, "src//PredictionIntervals//models//temp_weights")
if not os.path.exists(os.path.join(root, "src//PredictionIntervals//models//temp_weights")):
os.mkdir(os.path.join(root, "src//PredictionIntervals//models//temp_weights"))
if not os.path.exists(folder):
os.mkdir(folder)

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from src import utils
from tqdm import trange
from torch import optim
from src.DualAQD.models.network import *
from src.PredictionIntervals.models.network import *


# import matplotlib.pyplot as plt
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 5c7642e

Please sign in to comment.