Skip to content

Commit 6982af4

Browse files
Removed temp model dir, reworked save-checkpoints (#876)
1 parent d8e2dfb commit 6982af4

File tree

7 files changed

+13
-31
lines changed

7 files changed

+13
-31
lines changed

silnlp/common/environment.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525

2626
class SilNlpEnv:
2727
def __init__(self):
28+
atexit.register(self.delete_path)
2829
atexit.register(check_transfers)
2930
self.root_dir = Path.home() / ".silnlp"
3031
self.assets_dir = Path(__file__).parent.parent / "assets"
31-
self.temp_model_dir: Optional[Path] = None
32-
32+
self.path_to_delete: Optional[Path] = None
3333
self.set_data_dir()
3434

3535
def set_data_dir(self, data_dir: Optional[Path] = None):
@@ -109,15 +109,12 @@ def resolve_data_dir(self) -> Path:
109109

110110
raise FileExistsError("No valid path exists")
111111

112-
def get_temp_model_dir(self) -> Path:
113-
if not self.temp_model_dir:
114-
self.temp_model_dir = Path(tempfile.mkdtemp(prefix="silnlp_model_"))
115-
atexit.register(self.delete_temp_model_dir)
116-
return self.temp_model_dir
112+
def delete_path_on_exit(self, path: Union[str, Path]) -> None:
113+
self.path_to_delete = pathify(path)
117114

118-
def delete_temp_model_dir(self) -> None:
119-
if self.temp_model_dir and self.temp_model_dir.is_dir():
120-
shutil.rmtree(self.temp_model_dir)
115+
def delete_path(self) -> None:
116+
if self.path_to_delete and self.path_to_delete.is_dir():
117+
shutil.rmtree(self.path_to_delete)
121118

122119

123120
def check_transfers() -> None:

silnlp/nmt/clearml_connection.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class SILClearML:
2323
experiment_suffix: str = ""
2424
clearml_project_folder: str = ""
2525
commit: Optional[str] = None
26-
use_default_model_dir: bool = True
2726
tag: Optional[str] = None
2827
skip_config: bool = False
2928

@@ -89,7 +88,6 @@ def __post_init__(self) -> None:
8988
if self.commit:
9089
self.task.set_script(commit=self.commit)
9190
if self.queue_name.lower() not in ("local", "locally"):
92-
SIL_NLP_ENV.delete_temp_model_dir()
9391
self.task.execute_remotely(queue_name=self.queue_name)
9492
except LoginError as e:
9593
if self.queue_name is None:
@@ -131,7 +129,6 @@ def _load_config(self) -> None:
131129
config = yaml.safe_load(file)
132130
if config is None or len(config.keys()) == 0:
133131
raise RuntimeError("Config file has no contents.")
134-
config["use_default_model_dir"] = self.use_default_model_dir
135132
self.config = create_config(exp_dir, config)
136133
return
137134
# There is a ClearML task - lets' do more complex importing.
@@ -158,5 +155,4 @@ def _load_config(self) -> None:
158155
exp_dir.mkdir(parents=True, exist_ok=True)
159156
with (exp_dir / "config.yml").open("w+", encoding="utf-8") as file:
160157
yaml.safe_dump(data=config, stream=file)
161-
config["use_default_model_dir"] = self.use_default_model_dir
162158
self.config = create_config(exp_dir, config)

silnlp/nmt/config_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
from .hugging_face_config import HuggingFaceConfig
88

99

10-
def load_config(exp_name: str, use_default_model_dir: bool = True) -> Config:
10+
def load_config(exp_name: str) -> Config:
1111
exp_dir = get_mt_exp_dir(exp_name)
1212
config_path = exp_dir / "config.yml"
1313

1414
with config_path.open("r", encoding="utf-8") as file:
1515
config: dict = yaml.safe_load(file)
16-
config["use_default_model_dir"] = use_default_model_dir
1716
return create_config(exp_dir, config)
1817

1918

silnlp/nmt/experiment.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class SILExperiment:
2323
mixed_precision: bool = True
2424
num_devices: int = 1
2525
clearml_queue: Optional[str] = None
26-
save_checkpoints: bool = False
2726
run_prep: bool = False
2827
run_train: bool = False
2928
run_test: bool = False
@@ -40,7 +39,6 @@ def __post_init__(self):
4039
self.name,
4140
self.clearml_queue,
4241
commit=self.commit,
43-
use_default_model_dir=self.save_checkpoints,
4442
tag=self.clearml_tag,
4543
)
4644
self.name: str = self.clearml.name
@@ -89,7 +87,6 @@ def test(self):
8987
scorers=self.scorers,
9088
produce_multiple_translations=self.produce_multiple_translations,
9189
save_confidences=self.save_confidences,
92-
use_default_model_dir=self.save_checkpoints,
9390
)
9491

9592
def translate(self):
@@ -104,7 +101,6 @@ def translate(self):
104101
translator = TranslationTask(
105102
name=self.name,
106103
checkpoint=checkpoint,
107-
use_default_model_dir=self.save_checkpoints,
108104
commit=self.commit,
109105
)
110106

@@ -253,9 +249,6 @@ def main() -> None:
253249
args.train = True
254250
args.test = True
255251

256-
if not args.train:
257-
args.save_checkpoints = True
258-
259252
exp = SILExperiment(
260253
name=args.experiment,
261254
make_stats=args.stats,
@@ -265,7 +258,6 @@ def main() -> None:
265258
clearml_queue=args.clearml_queue,
266259
clearml_tag=args.clearml_tag,
267260
commit=args.commit,
268-
save_checkpoints=args.save_checkpoints,
269261
run_prep=args.preprocess,
270262
run_train=args.train,
271263
run_test=args.test,
@@ -275,6 +267,9 @@ def main() -> None:
275267
scorers=set(s.lower() for s in args.scorers),
276268
score_by_book=args.score_by_book,
277269
)
270+
271+
if not args.save_checkpoints:
272+
SIL_NLP_ENV.delete_path_on_exit(get_mt_exp_dir(args.experiment) / "run")
278273
exp.run()
279274

280275

silnlp/nmt/hugging_face_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@ def get_parent_model_name(parent_exp: str) -> str:
314314

315315
class HuggingFaceConfig(Config):
316316
def __init__(self, exp_dir: Path, config: dict) -> None:
317-
ckpt_dir = str(exp_dir / "run") if config["use_default_model_dir"] else SIL_NLP_ENV.get_temp_model_dir()
318317
config = merge_dict(
319318
{
320319
"data": {
@@ -340,7 +339,7 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
340339
"auto_grad_acc": False,
341340
"max_steps": 5000,
342341
"group_by_length": True,
343-
"output_dir": ckpt_dir,
342+
"output_dir": str(exp_dir / "run"),
344343
"delete_checkpoint_optimizer_state": True,
345344
"delete_checkpoint_tokenizer": True,
346345
"log_level": "info",

silnlp/nmt/test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,10 +664,9 @@ def test(
664664
by_book: bool = False,
665665
produce_multiple_translations: bool = False,
666666
save_confidences: bool = False,
667-
use_default_model_dir: bool = True,
668667
):
669668
exp_name = experiment
670-
config = load_config(exp_name, use_default_model_dir)
669+
config = load_config(exp_name)
671670

672671
if not any(config.exp_dir.glob("test*.src.txt")):
673672
LOGGER.info("No test dataset.")
@@ -885,7 +884,6 @@ def main() -> None:
885884
by_book=args.by_book,
886885
produce_multiple_translations=args.multiple_translations,
887886
save_confidences=args.save_confidences,
888-
use_default_model_dir=True,
889887
)
890888

891889

silnlp/nmt/translate.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __exit__(
4747
class TranslationTask:
4848
name: str
4949
checkpoint: Union[str, int] = "last"
50-
use_default_model_dir: bool = True
5150
clearml_queue: Optional[str] = None
5251
commit: Optional[str] = None
5352
clearml_tag: Optional[str] = None
@@ -284,7 +283,6 @@ def _init_translation_task(self, experiment_suffix: str) -> Tuple[Translator, Co
284283
project_suffix="_infer",
285284
experiment_suffix=experiment_suffix,
286285
commit=self.commit,
287-
use_default_model_dir=self.use_default_model_dir,
288286
tag=self.clearml_tag,
289287
)
290288
self.name = clearml.name

0 commit comments

Comments
 (0)