import os
import sys
from chainer.training import extension
from chainer.training.extensions import log_report as log_report_module
from chainer.training.extensions import util
[docs]class PrintReport(extension.Extension):
"""Trainer extension to print the accumulated results.
This extension uses the log accumulated by a :class:`LogReport` extension
to print specified entries of the log in a human-readable format.
Args:
entries (list of str): List of keys of observations to print.
log_report (str or LogReport): Log report to accumulate the
observations. This is either the name of a LogReport extensions
registered to the trainer, or a LogReport instance to use
internally.
out: Stream to print the bar. Standard output is used by default.
"""
def __init__(self, entries, log_report='LogReport', out=sys.stdout):
self._entries = entries
self._log_report = log_report
self._out = out
self._log_len = 0 # number of observations already printed
# format information
entry_widths = [max(10, len(s)) for s in entries]
header = ' '.join(('{:%d}' % w for w in entry_widths)).format(
*entries) + '\n'
self._header = header # printed at the first call
templates = []
for entry, w in zip(entries, entry_widths):
templates.append((entry, '{:<%dg} ' % w, ' ' * (w + 2)))
self._templates = templates
def __call__(self, trainer):
out = self._out
if self._header:
out.write(self._header)
self._header = None
log_report = self._log_report
if isinstance(log_report, str):
log_report = trainer.get_extension(log_report)
elif isinstance(log_report, log_report_module.LogReport):
log_report(trainer) # update the log report
else:
raise TypeError('log report has a wrong type %s' %
type(log_report))
log = log_report.log
log_len = self._log_len
while len(log) > log_len:
# delete the printed contents from the current cursor
if os.name == 'nt':
util.erase_console(0, 0)
else:
out.write('\033[J')
self._print(log[log_len])
log_len += 1
self._log_len = log_len
def serialize(self, serializer):
log_report = self._log_report
if isinstance(log_report, log_report_module.LogReport):
log_report.serialize(serializer['_log_report'])
def _print(self, observation):
out = self._out
for entry, template, empty in self._templates:
if entry in observation:
out.write(template.format(observation[entry]))
else:
out.write(empty)
out.write('\n')