Skip to content

Commit

Permalink
Refactor dataset classes to remove unnecessary imports and simplify code + included changes from comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Jad-yehya committed Sep 16, 2024
1 parent ebdb0ce commit a780de6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
11 changes: 6 additions & 5 deletions datasets/msl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from benchopt import BaseDataset, safe_import_context, config

with safe_import_context() as import_ctx:
import pathlib
import numpy as np
import requests

Expand Down Expand Up @@ -35,15 +34,17 @@ class Dataset(BaseDataset):
def get_data(self):
path = config.get_data_path(key="MSL")
# Check if the data is already here
if not pathlib.Path.exists(path):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)

response = requests.get(URL_XTRAIN)
with open(pathlib.Path(path) / "MSL_train.npy", "wb") as f:
with open(path / "MSL_train.npy", "wb") as f:
f.write(response.content)
response = requests.get(URL_XTEST)
with open(pathlib.Path(path) / "MSL_test.npy", "wb") as f:
with open(path / "MSL_test.npy", "wb") as f:
f.write(response.content)
response = requests.get(URL_YTEST)
with open(pathlib.Path(path) / "MSL_test_label.npy", "wb") as f:
with open(path / "MSL_test_label.npy", "wb") as f:
f.write(response.content)

X_train = np.load(path / "MSL_train.npy")
Expand Down
5 changes: 3 additions & 2 deletions datasets/psm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
with safe_import_context() as import_ctx:
import requests
import pandas as pd
import pathlib

URL_XTRAIN = (
"https://drive.google.com/uc?&id=1d3tAbYTj0CZLhB7z3IDTfTRg3E7qj_tw"
Expand All @@ -30,8 +29,10 @@ class Dataset(BaseDataset):
def get_data(self):
# Check if the data is already here
path = config.get_data_path(key="PSM")
# Check if the data is already here
if not path.exists():
path.mkdir(parents=True, exist_ok=True)

if not pathlib.Path.exists(path):
response = requests.get(URL_XTRAIN)
with open(path / "PSM_train.csv", "wb") as f:
f.write(response.content)
Expand Down
10 changes: 5 additions & 5 deletions datasets/smap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from benchopt import BaseDataset, safe_import_context, config

with safe_import_context() as import_ctx:
import pathlib
import numpy as np
import requests
# from sklearn.model_selection import TimeSeriesSplit
Expand Down Expand Up @@ -36,18 +35,19 @@ def get_data(self):
path = config.get_data_path(key="SMAP")

# Check if the data is already here
if not pathlib.Path.exists(path):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)

response = requests.get(URL_XTRAIN)
with open(pathlib.Path(path) / "SMAP_train.npy", "wb") as f:
with open(path / "SMAP_train.npy", "wb") as f:
f.write(response.content)

response = requests.get(URL_XTEST)
with open(pathlib.Path(path) / "SMAP_test.npy", "wb") as f:
with open(path / "SMAP_test.npy", "wb") as f:
f.write(response.content)

response = requests.get(URL_YTEST)
with open(pathlib.Path(path) / "SMAP_test_label.npy", "wb") as f:
with open(path / "SMAP_test_label.npy", "wb") as f:
f.write(response.content)

X_train = np.load(path / "SMAP_train.npy")
Expand Down

0 comments on commit a780de6

Please sign in to comment.