Skip to content

Commit

Permalink
cleanup forecast; closes #34
Browse files Browse the repository at this point in the history
Signed-off-by: ivelin <[email protected]>
  • Loading branch information
ivelin committed Mar 29, 2024
1 parent 3e26fa2 commit 160a503
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ tmp/**
.vscode
canswim_model.pt*
build/**
dist/**
dist/**
data/
forecast/
28 changes: 23 additions & 5 deletions src/canswim/hfhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from darts.models.forecasting.forecasting_model import ForecastingModel
from huggingface_hub import snapshot_download, upload_folder, create_repo
import torch
import canswim
import tarfile
import os.path


class HFHub:
Expand Down Expand Up @@ -154,16 +155,22 @@ def download_data(self, repo_id: str = None, local_dir: str = None):
data_dir = self.data_dir
if repo_id is None:
repo_id = self.repo_id
logger.info(
f"Downloading hf data from {repo_id} to data dir:\n",
os.listdir(data_dir),
)
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=data_dir,
token=self.HF_TOKEN,
)
logger.info(
f"Downloaded hf dataset files from {repo_id} to data dir:\n",
os.listdir(data_dir),
)
# Unpack forecast parquet files from tar
forecast_dir = f"{data_dir}/forecast/"
forecast_tar = f"{data_dir}/forecast.tar"
with tarfile.open(forecast_tar, "r:gz") as tar:
logger.info(f"Extracting {forecast_tar} to folder {forecast_dir}")
tar.extractall(path=forecast_dir, filter="data")

def upload_data(
self, repo_id: str = None, private: bool = True, local_dir: str = None
Expand Down Expand Up @@ -193,12 +200,23 @@ def upload_data(
logger.info(
f"repo_info: {repo_info}",
)
# Compress forecast parquet files to pass hfhub limitation of 25k LFS files
forecast_dir = f"{data_dir}/forecast/"
forecast_tar = f"{data_dir}/forecast.tar"
with tarfile.open(forecast_tar, "w:gz") as tar:
logger.info(f"Creating {forecast_tar} from folder {forecast_dir}")
tar.add(forecast_dir, arcname=os.path.basename(forecast_dir))
# upload select files to hfhub
logger.info(f"uploading folder {data_dir}")
upload_folder(
repo_id=repo_id,
# path_in_repo="data-3rd-party",
repo_type="dataset",
folder_path=data_dir,
token=self.HF_TOKEN,
# ignore_patterns=[forecast_dir],
# allow_patterns="",
# delete_patterns=[forecast_dir],
)
logger.info(
"Upload finished.",
Expand Down

0 comments on commit 160a503

Please sign in to comment.