""" Statistics calculation utility """
import time
import math
import sys
from onmt.utils.logging import logger


class Statistics(object):
    """
    Accumulator for loss statistics.
    Currently calculates:

    * accuracy
    * perplexity
    * elapsed time
    """

    def __init__(
        self, loss=0, n_batchs=0, n_sents=0, n_words=0, n_correct=0,
        computed_metrics=None,
    ):
        self.loss = loss
        self.n_batchs = n_batchs
        self.n_sents = n_sents
        self.n_words = n_words
        self.n_correct = n_correct
        self.n_src_words = 0
        # Default to None rather than {} so instances never share one
        # mutable dict through the default argument.
        self.computed_metrics = computed_metrics if computed_metrics is not None else {}
        self.start_time = time.time()
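
    # Illustrative sketch (variable names here are assumptions, not the
    # trainer's actual code): a per-batch Statistics is built from the summed
    # loss and token counts of one batch, then folded into a running total
    # with `update` below:
    #
    #     batch_stats = Statistics(loss=batch_loss, n_batchs=1,
    #                              n_sents=batch_size, n_words=num_tgt_tokens,
    #                              n_correct=num_correct)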

    @staticmethod
    def all_gather_stats(stat, max_size=4096):
        """
        Gather a `Statistics` object across multiple processes/nodes

        Args:
            stat(:obj:`Statistics`): the statistics object to gather
                across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            `Statistics`, the updated stats object
        """
        stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
        return stats[0]

    @staticmethod
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats
        """
        from torch.distributed import get_rank
        from onmt.utils.distributed import all_gather_list

        # Get a list of world_size lists with len(stat_list) Statistics objects
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        for other_rank, stats in enumerate(all_stats):
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats
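
    # Illustrative usage sketch (an assumption, not trainer code from this
    # module): with torch.distributed initialized, each rank folds the other
    # ranks' statistics into its own copy, roughly like
    #
    #     local_stats = ...                                  # this rank's Statistics
    #     global_stats = Statistics.all_gather_stats(local_stats)
    #     global_stats.output(step, train_steps, learning_rate, start_time)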

    def update(self, stat, update_n_src_words=False):
        """
        Update statistics by summing values with another `Statistics` object

        Args:
            stat: another statistic object
            update_n_src_words(bool): whether to update (sum) `n_src_words`
                or not
        """
        self.loss += stat.loss
        self.n_batchs += stat.n_batchs
        self.n_sents += stat.n_sents
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct
        # Computed metrics are replaced with the latest values, not summed.
        self.computed_metrics = stat.computed_metrics

        if update_n_src_words:
            self.n_src_words += stat.n_src_words
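
    # Illustrative sketch: counters are summed in place, e.g.
    #
    #     total = Statistics()
    #     total.update(batch_stats)                           # n_src_words untouched
    #     total.update(batch_stats, update_n_src_words=True)  # also sums n_src_words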

    def accuracy(self):
        """compute accuracy"""
        return 100 * (self.n_correct / self.n_words)

    def xent(self):
        """compute cross entropy"""
        return self.loss / self.n_words

    def ppl(self):
        """compute perplexity"""
        # The exponent is clamped at 100 so math.exp cannot overflow
        # when the loss is very large early in training.
        return math.exp(min(self.loss / self.n_words, 100))
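
    # Worked example (illustrative numbers): with loss=69.3 summed over
    # n_words=50 target tokens, xent() = 69.3 / 50 = 1.386 and
    # ppl() = exp(1.386) ≈ 4.0; perplexity is exp(cross-entropy).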

    def elapsed_time(self):
        """compute elapsed time"""
        return time.time() - self.start_time

    def output(self, step, num_steps, learning_rate, start):
        """Write out statistics to stdout.

        Args:
            step (int): current step
            num_steps (int): total number of training steps
            learning_rate (float): current learning rate
            start (float): start time of the step
        """
        t = self.elapsed_time()
        step_fmt = "%2d" % step
        if num_steps > 0:
            step_fmt = "%s/%5d" % (step_fmt, num_steps)
        logger.info(
            (
                "Step %s; acc: %2.1f; ppl: %5.1f; xent: %2.1f; "
                + "lr: %7.5f; sents: %7.0f; bsz: %4.0f/%4.0f/%2.0f; "
                + "%3.0f/%3.0f tok/s; %6.0f sec;"
            )
            % (
                step_fmt,
                self.accuracy(),
                self.ppl(),
                self.xent(),
                learning_rate,
                self.n_sents,
                self.n_src_words / self.n_batchs,  # avg source tokens per batch
                self.n_words / self.n_batchs,  # avg target tokens per batch
                self.n_sents / self.n_batchs,  # avg sentences per batch
                self.n_src_words / (t + 1e-5),  # source tokens per second
                self.n_words / (t + 1e-5),  # target tokens per second
                time.time() - start,
            )
            + "".join(
                [
                    " {}: {}".format(k, round(v, 2))
                    for k, v in self.computed_metrics.items()
                ]
            )
        )
        sys.stdout.flush()
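
    # The "bsz" triple is avg source tokens / target tokens / sentences per
    # batch, so a rendered line looks roughly like (illustrative values, not
    # captured output):
    #
    #   Step 100/ 1000; acc: 63.0; ppl:   4.0; xent: 1.4; lr: 0.00100; ...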

    def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
        """Log statistics to TensorBoard."""
        t = self.elapsed_time()
        writer.add_scalar(prefix + "/xent", self.xent(), step)
        writer.add_scalar(prefix + "/ppl", self.ppl(), step)
        for k, v in self.computed_metrics.items():
            writer.add_scalar(prefix + "/" + k, round(v, 4), step)
        writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
        writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
        writer.add_scalar(prefix + "/lr", learning_rate, step)
        if patience is not None:
            writer.add_scalar(prefix + "/patience", patience, step)