From dcfafe3d94b9105036ccf30f3dc713798df6530a Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Wed, 12 Feb 2025 17:37:38 +0100 Subject: [PATCH] Keep track of datasets stats - log at validation (#209) --- eole/trainer.py | 4 +-- eole/utils/loss.py | 22 ++++++++++++++--- eole/utils/report_manager.py | 5 ++-- eole/utils/statistics.py | 47 +++++++++++++++++++++--------------- 4 files changed, 52 insertions(+), 26 deletions(-) diff --git a/eole/trainer.py b/eole/trainer.py index 6c819695..e1c7d3e9 100644 --- a/eole/trainer.py +++ b/eole/trainer.py @@ -456,8 +456,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, repor src = batch["src"] src_len = batch["srclen"] if src_len is not None: - report_stats.n_src_words += src_len.sum().item() - total_stats.n_src_words += src_len.sum().item() + report_stats.n_src_tokens += src_len.sum().item() + total_stats.n_src_tokens += src_len.sum().item() tgt = batch["tgt"] kwargs = {} if "images" in batch.keys(): diff --git a/eole/utils/loss.py b/eole/utils/loss.py index 59cdabaf..a6adabca 100644 --- a/eole/utils/loss.py +++ b/eole/utils/loss.py @@ -322,11 +322,19 @@ def forward(self, batch, output, attns, estim=None): estimloss = torch.tensor([0.0], device=loss.device) n_sents = len(batch["srclen"]) - stats = self._stats(n_sents, loss.sum().item(), estimloss.item(), scores, flat_tgt) + stats = self._stats( + n_sents, + loss.sum().item(), + estimloss.item(), + scores, + flat_tgt, + batch["cid"], + batch["cid_line_number"], + ) return loss, stats, estimloss - def _stats(self, bsz, loss, auxloss, scores, target): + def _stats(self, bsz, loss, auxloss, scores, target, cids, cids_idx): """ Args: loss (int): the loss computed by the loss criterion. @@ -346,11 +354,19 @@ def _stats(self, bsz, loss, auxloss, scores, target): n_batchs = 1 if bsz else 0 # in the case criterion reduction is None then we need # to sum the loss of each sentence in the batch + data = {} + for cid, idx in zip(cids, cids_idx): + if cid not in data.keys(): + data[cid] = {"count": 1, "index": idx} + else: + data[cid]["count"] += 1 + data[cid]["index"] = min(idx, data[cid]["index"]) return eole.utils.Statistics( loss=loss, auxloss=auxloss, n_batchs=n_batchs, n_sents=bsz, - n_words=num_non_padding, + n_tokens=num_non_padding, n_correct=num_correct, + data_stats=data, ) diff --git a/eole/utils/report_manager.py b/eole/utils/report_manager.py index 060a8bc3..bb9daeef 100644 --- a/eole/utils/report_manager.py +++ b/eole/utils/report_manager.py @@ -129,11 +129,12 @@ def _report_step(self, lr, patience, step, valid_stats=None, train_stats=None): self.log("Train perplexity: %g" % train_stats.ppl()) self.log("Train accuracy: %g" % train_stats.accuracy()) self.log("Sentences processed: %g" % train_stats.n_sents) + self.log(train_stats.data_stats) self.log( "Average bsz: %4.0f/%4.0f/%2.0f" % ( - train_stats.n_src_words / train_stats.n_batchs, - train_stats.n_words / train_stats.n_batchs, + train_stats.n_src_tokens / train_stats.n_batchs, + train_stats.n_tokens / train_stats.n_batchs, train_stats.n_sents / train_stats.n_batchs, ) ) diff --git a/eole/utils/statistics.py b/eole/utils/statistics.py index 3c479c0d..2d4f0c45 100644 --- a/eole/utils/statistics.py +++ b/eole/utils/statistics.py @@ -23,18 +23,20 @@ def __init__( auxloss=0, n_batchs=0, n_sents=0, - n_words=0, + n_tokens=0, n_correct=0, - computed_metrics={}, + computed_metrics=None, + data_stats=None, ): self.loss = loss self.auxloss = auxloss self.n_batchs = n_batchs self.n_sents = n_sents - self.n_words = n_words + self.n_tokens = n_tokens self.n_correct = n_correct - self.n_src_words = 0 - self.computed_metrics = computed_metrics + self.n_src_tokens = 0 + self.computed_metrics = computed_metrics if computed_metrics is not None else {} + self.data_stats = data_stats if data_stats is not None else {} self.start_time = time.time() @staticmethod @@ -78,16 +80,16 @@ def all_gather_stats_list(stat_list, max_size=4096): if other_rank == our_rank: continue for i, stat in enumerate(stats): - our_stats[i].update(stat, update_n_src_words=True) + our_stats[i].update(stat, update_n_src_tokens=True) return our_stats - def update(self, stat, update_n_src_words=False): + def update(self, stat, update_n_src_tokens=False): """ Update statistics by suming values with another `Statistics` object Args: stat: another statistic object - update_n_src_words(bool): whether to update (sum) `n_src_words` + update_n_src_tokens(bool): whether to update (sum) `n_src_tokens` or not """ @@ -95,12 +97,19 @@ def update(self, stat, update_n_src_words=False): self.auxloss += stat.auxloss self.n_batchs += stat.n_batchs self.n_sents += stat.n_sents - self.n_words += stat.n_words + self.n_tokens += stat.n_tokens self.n_correct += stat.n_correct self.computed_metrics = stat.computed_metrics + for cid in stat.data_stats.keys(): + if cid in self.data_stats.keys(): + self.data_stats[cid]["count"] += stat.data_stats[cid]["count"] + else: + self.data_stats[cid] = {} + self.data_stats[cid]["count"] = stat.data_stats[cid]["count"] + self.data_stats[cid]["index"] = stat.data_stats[cid]["index"] - if update_n_src_words: - self.n_src_words += stat.n_src_words + if update_n_src_tokens: + self.n_src_tokens += stat.n_src_tokens def computed_metric(self, metric): """check if metric(TER/BLEU) is computed and return it""" @@ -109,18 +118,18 @@ def computed_metric(self, metric): def accuracy(self): """compute accuracy""" - return 100 * (self.n_correct / self.n_words) + return 100 * (self.n_correct / self.n_tokens) def xent(self): """compute cross entropy""" - return self.loss / self.n_words + return self.loss / self.n_tokens def aux_loss(self): return self.auxloss / self.n_sents def ppl(self): """compute perplexity""" - return math.exp(min(self.loss / self.n_words, 100)) + return math.exp(min(self.loss / self.n_tokens, 100)) def elapsed_time(self): """compute elapsed time""" @@ -152,11 +161,11 @@ def output(self, step, num_steps, learning_rate, start): self.aux_loss(), learning_rate, self.n_sents, - self.n_src_words / self.n_batchs, - self.n_words / self.n_batchs, + self.n_src_tokens / self.n_batchs, + self.n_tokens / self.n_batchs, self.n_sents / self.n_batchs, - self.n_src_words / (t + 1e-5), - self.n_words / (t + 1e-5), + self.n_src_tokens / (t + 1e-5), + self.n_tokens / (t + 1e-5), time.time() - start, ) + "".join([" {}: {}".format(k, round(v, 2)) for k, v in self.computed_metrics.items()]) @@ -171,7 +180,7 @@ def log_tensorboard(self, prefix, writer, learning_rate, patience, step): for k, v in self.computed_metrics.items(): writer.add_scalar(prefix + "/" + k, round(v, 4), step) writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) - writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) + writer.add_scalar(prefix + "/tgtper", self.n_tokens / t, step) writer.add_scalar(prefix + "/lr", learning_rate, step) if patience is not None: writer.add_scalar(prefix + "/patience", patience, step)