Skip to content

Commit

Permalink
remove verbosity at validation/scoring (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s authored Jan 15, 2025
1 parent 0ac626a commit 79a10be
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions eole/utils/scoring_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,27 @@ def translate(self, model, gpu_rank, step):
# Translator #
# ########## #

# Build translator from options
model_config = self.config.model
model_config._validate_model_config()

# This is somewhat broken and we shall remove or improve
# (take 'inference' field of config if exists?)
# Set "default" translation options on empty cfgfile
predict_config = PredictConfig(model_path=["dummy"], src="dummy")
predict_config.compute_dtype = self.config.training.compute_dtype
if predict_config.transforms_configs.prefix.tgt_prefix != "":
predict_config.tgt_file_prefix = True
predict_config.beam_size = 1 # prevent OOM when GPU is almost full at training
predict_config._validate_predict_config()
# Build translator from options
self.config.training.num_workers = 0
predict_config = PredictConfig(
model_path=["dummy"],
src=self.config.data["valid"].path_src,
compute_dtype=self.config.training.compute_dtype,
beam_size=1,
transforms=self.config.transforms,
transforms_configs=self.config.transforms_configs,
model=model_config,
tgt_file_prefix=self.config.transforms_configs.prefix.tgt_prefix != "",
gpu_ranks=[gpu_rank],
)

scorer = GNMTGlobalScorer.from_config(predict_config)
model_config = self.config.model
model_config._validate_model_config()
translator = Translator.from_config( # we need to review opt/config stuff in translator
model,
self.vocabs,
Expand All @@ -76,11 +84,6 @@ def translate(self, model, gpu_rank, step):
# ################### #

# Reinstantiate the validation iterator
self.config.training.num_workers = 0
predict_config.src = self.config.data["valid"].path_src
predict_config.transforms = self.config.transforms
predict_config.transforms_configs = self.config.transforms_configs
predict_config.model = model_config
# Retrieve raw references and sources
with codecs.open(self.config.data["valid"].path_tgt, "r", encoding="utf-8") as f:
raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
Expand Down

0 comments on commit 79a10be

Please sign in to comment.