Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions silnlp/common/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@

class SilNlpEnv:
def __init__(self):
atexit.register(self.delete_path)
atexit.register(check_transfers)
self.root_dir = Path.home() / ".silnlp"
self.assets_dir = Path(__file__).parent.parent / "assets"
self.temp_model_dir: Optional[Path] = None

self.path_to_delete: Optional[Path] = None
self.set_data_dir()

def set_data_dir(self, data_dir: Optional[Path] = None):
Expand Down Expand Up @@ -109,15 +109,12 @@ def resolve_data_dir(self) -> Path:

raise FileExistsError("No valid path exists")

def get_temp_model_dir(self) -> Path:
if not self.temp_model_dir:
self.temp_model_dir = Path(tempfile.mkdtemp(prefix="silnlp_model_"))
atexit.register(self.delete_temp_model_dir)
return self.temp_model_dir
def delete_path_on_exit(self, path: Union[str, Path]) -> None:
self.path_to_delete = pathify(path)

def delete_temp_model_dir(self) -> None:
if self.temp_model_dir and self.temp_model_dir.is_dir():
shutil.rmtree(self.temp_model_dir)
def delete_path(self) -> None:
if self.path_to_delete and self.path_to_delete.is_dir():
shutil.rmtree(self.path_to_delete)


def check_transfers() -> None:
Expand Down
4 changes: 0 additions & 4 deletions silnlp/nmt/clearml_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class SILClearML:
experiment_suffix: str = ""
clearml_project_folder: str = ""
commit: Optional[str] = None
use_default_model_dir: bool = True
tag: Optional[str] = None
skip_config: bool = False

Expand Down Expand Up @@ -89,7 +88,6 @@ def __post_init__(self) -> None:
if self.commit:
self.task.set_script(commit=self.commit)
if self.queue_name.lower() not in ("local", "locally"):
SIL_NLP_ENV.delete_temp_model_dir()
self.task.execute_remotely(queue_name=self.queue_name)
except LoginError as e:
if self.queue_name is None:
Expand Down Expand Up @@ -131,7 +129,6 @@ def _load_config(self) -> None:
config = yaml.safe_load(file)
if config is None or len(config.keys()) == 0:
raise RuntimeError("Config file has no contents.")
config["use_default_model_dir"] = self.use_default_model_dir
self.config = create_config(exp_dir, config)
return
# There is a ClearML task - let's do more complex importing.
Expand All @@ -158,5 +155,4 @@ def _load_config(self) -> None:
exp_dir.mkdir(parents=True, exist_ok=True)
with (exp_dir / "config.yml").open("w+", encoding="utf-8") as file:
yaml.safe_dump(data=config, stream=file)
config["use_default_model_dir"] = self.use_default_model_dir
self.config = create_config(exp_dir, config)
3 changes: 1 addition & 2 deletions silnlp/nmt/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
from .hugging_face_config import HuggingFaceConfig


def load_config(exp_name: str, use_default_model_dir: bool = True) -> Config:
def load_config(exp_name: str) -> Config:
exp_dir = get_mt_exp_dir(exp_name)
config_path = exp_dir / "config.yml"

with config_path.open("r", encoding="utf-8") as file:
config: dict = yaml.safe_load(file)
config["use_default_model_dir"] = use_default_model_dir
return create_config(exp_dir, config)


Expand Down
11 changes: 3 additions & 8 deletions silnlp/nmt/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class SILExperiment:
mixed_precision: bool = True
num_devices: int = 1
clearml_queue: Optional[str] = None
save_checkpoints: bool = False
run_prep: bool = False
run_train: bool = False
run_test: bool = False
Expand All @@ -40,7 +39,6 @@ def __post_init__(self):
self.name,
self.clearml_queue,
commit=self.commit,
use_default_model_dir=self.save_checkpoints,
tag=self.clearml_tag,
)
self.name: str = self.clearml.name
Expand Down Expand Up @@ -89,7 +87,6 @@ def test(self):
scorers=self.scorers,
produce_multiple_translations=self.produce_multiple_translations,
save_confidences=self.save_confidences,
use_default_model_dir=self.save_checkpoints,
)

def translate(self):
Expand All @@ -104,7 +101,6 @@ def translate(self):
translator = TranslationTask(
name=self.name,
checkpoint=checkpoint,
use_default_model_dir=self.save_checkpoints,
commit=self.commit,
)

Expand Down Expand Up @@ -253,9 +249,6 @@ def main() -> None:
args.train = True
args.test = True

if not args.train:
args.save_checkpoints = True

exp = SILExperiment(
name=args.experiment,
make_stats=args.stats,
Expand All @@ -265,7 +258,6 @@ def main() -> None:
clearml_queue=args.clearml_queue,
clearml_tag=args.clearml_tag,
commit=args.commit,
save_checkpoints=args.save_checkpoints,
run_prep=args.preprocess,
run_train=args.train,
run_test=args.test,
Expand All @@ -275,6 +267,9 @@ def main() -> None:
scorers=set(s.lower() for s in args.scorers),
score_by_book=args.score_by_book,
)

if not args.save_checkpoints:
SIL_NLP_ENV.delete_path_on_exit(get_mt_exp_dir(args.experiment) / "run")
exp.run()


Expand Down
3 changes: 1 addition & 2 deletions silnlp/nmt/hugging_face_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def get_parent_model_name(parent_exp: str) -> str:

class HuggingFaceConfig(Config):
def __init__(self, exp_dir: Path, config: dict) -> None:
ckpt_dir = str(exp_dir / "run") if config["use_default_model_dir"] else SIL_NLP_ENV.get_temp_model_dir()
config = merge_dict(
{
"data": {
Expand All @@ -340,7 +339,7 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
"auto_grad_acc": False,
"max_steps": 5000,
"group_by_length": True,
"output_dir": ckpt_dir,
"output_dir": str(exp_dir / "run"),
"delete_checkpoint_optimizer_state": True,
"delete_checkpoint_tokenizer": True,
"log_level": "info",
Expand Down
4 changes: 1 addition & 3 deletions silnlp/nmt/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,9 @@ def test(
by_book: bool = False,
produce_multiple_translations: bool = False,
save_confidences: bool = False,
use_default_model_dir: bool = True,
):
exp_name = experiment
config = load_config(exp_name, use_default_model_dir)
config = load_config(exp_name)

if not any(config.exp_dir.glob("test*.src.txt")):
LOGGER.info("No test dataset.")
Expand Down Expand Up @@ -885,7 +884,6 @@ def main() -> None:
by_book=args.by_book,
produce_multiple_translations=args.multiple_translations,
save_confidences=args.save_confidences,
use_default_model_dir=True,
)


Expand Down
2 changes: 0 additions & 2 deletions silnlp/nmt/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __exit__(
class TranslationTask:
name: str
checkpoint: Union[str, int] = "last"
use_default_model_dir: bool = True
clearml_queue: Optional[str] = None
commit: Optional[str] = None
clearml_tag: Optional[str] = None
Expand Down Expand Up @@ -284,7 +283,6 @@ def _init_translation_task(self, experiment_suffix: str) -> Tuple[Translator, Co
project_suffix="_infer",
experiment_suffix=experiment_suffix,
commit=self.commit,
use_default_model_dir=self.use_default_model_dir,
tag=self.clearml_tag,
)
self.name = clearml.name
Expand Down