Skip to content

Commit

Permalink
Keep track of datasets stats - log at validation (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s authored Feb 12, 2025
1 parent 452d370 commit dcfafe3
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 26 deletions.
4 changes: 2 additions & 2 deletions eole/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats, repor
src = batch["src"]
src_len = batch["srclen"]
if src_len is not None:
report_stats.n_src_words += src_len.sum().item()
total_stats.n_src_words += src_len.sum().item()
report_stats.n_src_tokens += src_len.sum().item()
total_stats.n_src_tokens += src_len.sum().item()
tgt = batch["tgt"]
kwargs = {}
if "images" in batch.keys():
Expand Down
22 changes: 19 additions & 3 deletions eole/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,19 @@ def forward(self, batch, output, attns, estim=None):
estimloss = torch.tensor([0.0], device=loss.device)
n_sents = len(batch["srclen"])

stats = self._stats(n_sents, loss.sum().item(), estimloss.item(), scores, flat_tgt)
stats = self._stats(
n_sents,
loss.sum().item(),
estimloss.item(),
scores,
flat_tgt,
batch["cid"],
batch["cid_line_number"],
)

return loss, stats, estimloss

def _stats(self, bsz, loss, auxloss, scores, target):
def _stats(self, bsz, loss, auxloss, scores, target, cids, cids_idx):
"""
Args:
loss (int): the loss computed by the loss criterion.
Expand All @@ -346,11 +354,19 @@ def _stats(self, bsz, loss, auxloss, scores, target):
n_batchs = 1 if bsz else 0
# in the case criterion reduction is None then we need
# to sum the loss of each sentence in the batch
data = {}
for cid, idx in zip(cids, cids_idx):
if cid not in data.keys():
data[cid] = {"count": 1, "index": idx}
else:
data[cid]["count"] += 1
data[cid]["index"] = min(idx, data[cid]["index"])
return eole.utils.Statistics(
loss=loss,
auxloss=auxloss,
n_batchs=n_batchs,
n_sents=bsz,
n_words=num_non_padding,
n_tokens=num_non_padding,
n_correct=num_correct,
data_stats=data,
)
5 changes: 3 additions & 2 deletions eole/utils/report_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,12 @@ def _report_step(self, lr, patience, step, valid_stats=None, train_stats=None):
self.log("Train perplexity: %g" % train_stats.ppl())
self.log("Train accuracy: %g" % train_stats.accuracy())
self.log("Sentences processed: %g" % train_stats.n_sents)
self.log(train_stats.data_stats)
self.log(
"Average bsz: %4.0f/%4.0f/%2.0f"
% (
train_stats.n_src_words / train_stats.n_batchs,
train_stats.n_words / train_stats.n_batchs,
train_stats.n_src_tokens / train_stats.n_batchs,
train_stats.n_tokens / train_stats.n_batchs,
train_stats.n_sents / train_stats.n_batchs,
)
)
Expand Down
47 changes: 28 additions & 19 deletions eole/utils/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@ def __init__(
auxloss=0,
n_batchs=0,
n_sents=0,
n_words=0,
n_tokens=0,
n_correct=0,
computed_metrics={},
computed_metrics=None,
data_stats=None,
):
self.loss = loss
self.auxloss = auxloss
self.n_batchs = n_batchs
self.n_sents = n_sents
self.n_words = n_words
self.n_tokens = n_tokens
self.n_correct = n_correct
self.n_src_words = 0
self.computed_metrics = computed_metrics
self.n_src_tokens = 0
self.computed_metrics = computed_metrics if computed_metrics is not None else {}
self.data_stats = data_stats if data_stats is not None else {}
self.start_time = time.time()

@staticmethod
Expand Down Expand Up @@ -78,29 +80,36 @@ def all_gather_stats_list(stat_list, max_size=4096):
if other_rank == our_rank:
continue
for i, stat in enumerate(stats):
our_stats[i].update(stat, update_n_src_words=True)
our_stats[i].update(stat, update_n_src_tokens=True)
return our_stats

def update(self, stat, update_n_src_words=False):
def update(self, stat, update_n_src_tokens=False):
"""
Update statistics by summing values with another `Statistics` object
Args:
stat: another statistic object
update_n_src_words(bool): whether to update (sum) `n_src_words`
update_n_src_tokens(bool): whether to update (sum) `n_src_tokens`
or not
"""
self.loss += stat.loss
self.auxloss += stat.auxloss
self.n_batchs += stat.n_batchs
self.n_sents += stat.n_sents
self.n_words += stat.n_words
self.n_tokens += stat.n_tokens
self.n_correct += stat.n_correct
self.computed_metrics = stat.computed_metrics
for cid in stat.data_stats.keys():
if cid in self.data_stats.keys():
self.data_stats[cid]["count"] += stat.data_stats[cid]["count"]
else:
self.data_stats[cid] = {}
self.data_stats[cid]["count"] = stat.data_stats[cid]["count"]
self.data_stats[cid]["index"] = stat.data_stats[cid]["index"]

if update_n_src_words:
self.n_src_words += stat.n_src_words
if update_n_src_tokens:
self.n_src_tokens += stat.n_src_tokens

def computed_metric(self, metric):
"""check if metric(TER/BLEU) is computed and return it"""
Expand All @@ -109,18 +118,18 @@ def computed_metric(self, metric):

def accuracy(self):
"""compute accuracy"""
return 100 * (self.n_correct / self.n_words)
return 100 * (self.n_correct / self.n_tokens)

def xent(self):
"""compute cross entropy"""
return self.loss / self.n_words
return self.loss / self.n_tokens

def aux_loss(self):
return self.auxloss / self.n_sents

def ppl(self):
"""compute perplexity"""
return math.exp(min(self.loss / self.n_words, 100))
return math.exp(min(self.loss / self.n_tokens, 100))

def elapsed_time(self):
"""compute elapsed time"""
Expand Down Expand Up @@ -152,11 +161,11 @@ def output(self, step, num_steps, learning_rate, start):
self.aux_loss(),
learning_rate,
self.n_sents,
self.n_src_words / self.n_batchs,
self.n_words / self.n_batchs,
self.n_src_tokens / self.n_batchs,
self.n_tokens / self.n_batchs,
self.n_sents / self.n_batchs,
self.n_src_words / (t + 1e-5),
self.n_words / (t + 1e-5),
self.n_src_tokens / (t + 1e-5),
self.n_tokens / (t + 1e-5),
time.time() - start,
)
+ "".join([" {}: {}".format(k, round(v, 2)) for k, v in self.computed_metrics.items()])
Expand All @@ -171,7 +180,7 @@ def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
for k, v in self.computed_metrics.items():
writer.add_scalar(prefix + "/" + k, round(v, 4), step)
writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
writer.add_scalar(prefix + "/tgtper", self.n_tokens / t, step)
writer.add_scalar(prefix + "/lr", learning_rate, step)
if patience is not None:
writer.add_scalar(prefix + "/patience", patience, step)

0 comments on commit dcfafe3

Please sign in to comment.