try:
    import tensorflow as tf
except ImportError:
    tf = None

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO, BytesIO

try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

import numpy as np
from scipy.misc import toimage

from .base import Logger
from ....utils import torch_utils as tu
from ....utils import python_utils as pyu
from ....utils import train_utils as tru
from ....utils.exceptions import assert_


class TensorboardLogger(Logger):
    """Class to enable logging of training progress to Tensorboard.

    Currently supports logging scalars and images.
    """
    # Borrowed from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 and
    # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/tensorboard/logger.py
    def __init__(self, log_directory=None, log_scalars_every=None, log_images_every=None,
                 send_image_at_batch_indices='all', send_image_at_channel_indices='all',
                 send_volume_at_z_indices='mid'):
        """
        Parameters
        ----------
        log_directory : str
            Path to the directory where the log files will be placed.
        log_scalars_every : str or tuple or inferno.utils.train_utils.Frequency
            How often scalars should be logged to Tensorboard. By default, once every iteration.
        log_images_every : str or tuple or inferno.utils.train_utils.Frequency
            How often images should be logged to Tensorboard. By default, once every iteration.
        send_image_at_batch_indices : list or str
            Indices of the samples in the batch to be logged. An `image_batch` usually has the
            shape (num_samples, num_channels, num_rows, num_cols). Setting this argument to,
            say, [0, 2] logs only the images corresponding to `image_batch[0]` and
            `image_batch[2]`. When a str, it must be 'all', in which case all samples are logged.
        send_image_at_channel_indices : list or str
            Like `send_image_at_batch_indices`, but applied to channels.
        send_volume_at_z_indices : list or str
            For 3D batches of shape (num_samples, num_channels, num_z_slices, num_rows, num_cols),
            the indices of the z slices to be logged. When a str, it can be 'all' or 'mid'
            (to log the central z slice).

        Warnings
        --------
        Leaving `log_images_every` at the default (i.e. once every iteration) might generate a
        large logfile and/or slow down training.
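
        Examples
        --------
        A minimal, illustrative sketch. The frequency tuples and the 'training_accuracy'
        state key are examples rather than defaults, and attaching the logger to a trainer
        is handled by the surrounding inferno API:

        >>> logger = TensorboardLogger(log_directory='./logs',
        ...                            log_scalars_every=(1, 'iterations'),
        ...                            log_images_every=(20, 'iterations'))
        >>> _ = logger.observe_state('training_accuracy', observe_while='training')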
"""
        assert tf is not None, "tensorflow is required for TensorboardLogger but could not be imported."
        assert plt is not None, "matplotlib is required for TensorboardLogger but could not be imported."
        super(TensorboardLogger, self).__init__(log_directory=log_directory)
        self._log_scalars_every = None
        self._log_images_every = None
        self._writer = None
        self._config = {'image_batch_indices': send_image_at_batch_indices,
                        'image_channel_indices': send_image_at_channel_indices,
                        'volume_z_indices': send_volume_at_z_indices}
        # We ought to know the trainer states we're observing (and plotting to tensorboard).
        # These are the defaults.
        self._trainer_states_being_observed_while_training = {'training_loss',
                                                              'training_error',
                                                              'training_prediction',
                                                              'training_inputs',
                                                              'training_target',
                                                              'learning_rate'}
        self._trainer_states_being_observed_while_validating = {'validation_error_averaged',
                                                                'validation_loss_averaged'}
        if log_scalars_every is not None:
            self.log_scalars_every = log_scalars_every
        if log_images_every is not None:
            self.log_images_every = log_images_every

    @property
    def writer(self):
        if self._writer is None:
            self._writer = tf.summary.FileWriter(self.log_directory)
        return self._writer

    @property
    def log_scalars_every(self):
        if self._log_scalars_every is None:
            self._log_scalars_every = tru.Frequency(1, 'iterations')
        return self._log_scalars_every

    @log_scalars_every.setter
    def log_scalars_every(self, value):
        self._log_scalars_every = tru.Frequency.build_from(value)

    @property
    def log_scalars_now(self):
        # Using persistent=True in a property getter is probably not a very good idea...
        # We need to make sure that this getter is called only once per callback-call.
        return self.log_scalars_every.match(iteration_count=self.trainer.iteration_count,
                                            epoch_count=self.trainer.epoch_count,
                                            persistent=True)

    @property
    def log_images_every(self):
        if self._log_images_every is None:
            self._log_images_every = tru.Frequency(1, 'iterations')
        return self._log_images_every

    @log_images_every.setter
    def log_images_every(self, value):
        self._log_images_every = tru.Frequency.build_from(value)

    @property
    def log_images_now(self):
        # Using persistent=True in a property getter is probably not a very good idea...
        # We need to make sure that this getter is called only once per callback-call.
        return self.log_images_every.match(iteration_count=self.trainer.iteration_count,
                                           epoch_count=self.trainer.epoch_count,
                                           persistent=True)

    def observe_state(self, key, observe_while='training'):
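        """Observe a trainer state with the given key (e.g. 'training_loss') so that it is
        logged to Tensorboard, either while training or while validating.
        Returns self to allow chaining."""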
        # Validate arguments
        keyword_mapping = {'train': 'training',
                           'training': 'training',
                           'validation': 'validating',
                           'validating': 'validating'}
        observe_while = keyword_mapping.get(observe_while)
        assert_(observe_while is not None,
                "The keyword observe_while must be one of: {}."
                .format(set(keyword_mapping.keys())),
                ValueError)
        assert_(isinstance(key, str),
                "State key must be a string, got {} instead.".format(type(key).__name__),
                TypeError)
        # Add to set of observed states
        if observe_while == 'training':
            self._trainer_states_being_observed_while_training.add(key)
        elif observe_while == 'validating':
            self._trainer_states_being_observed_while_validating.add(key)
        else:
            raise NotImplementedError
        return self

    def observe_training_and_validation_state(self, key):
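        """Observe a state under both phases: for a given `key`, 'training_<key>' is observed
        while training and 'validation_<key>' while validating."""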
        for mode in ['training', 'validation']:
            self.observe_state('{}_{}'.format(mode, key), observe_while=mode)
        return self

    def observe_states(self, keys, observe_while='training'):
        for key in keys:
            self.observe_state(key, observe_while=observe_while)
        return self

    def observe_training_and_validation_states(self, keys):
        for key in keys:
            self.observe_training_and_validation_state(key)
        return self

    def log_object(self, tag, object_, allow_scalar_logging=True, allow_image_logging=True):
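        """Log an object to Tensorboard, dispatching on its type: scalars and scalar tensors
        are logged as scalars, image or volume tensors as image batches, and lists or tuples
        are logged element-wise under the tags '<tag>_0', '<tag>_1', etc. Anything else is
        silently ignored."""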
        assert isinstance(tag, str)
        if isinstance(object_, (list, tuple)):
            for object_num, _object in enumerate(object_):
                self.log_object("{}_{}".format(tag, object_num),
                                _object,
                                allow_scalar_logging,
                                allow_image_logging)
            return
        # Check whether object is a scalar
        if tu.is_scalar_tensor(object_) and allow_scalar_logging:
            # Log scalar
            value = object_.float()[0]
            self.log_scalar(tag, value, step=self.trainer.iteration_count)
        elif isinstance(object_, (float, int)) and allow_scalar_logging:
            value = float(object_)
            self.log_scalar(tag, value, step=self.trainer.iteration_count)
        elif tu.is_label_image_or_volume_tensor(object_) and allow_image_logging:
            # Add a channel axis and log as images
            self.log_image_or_volume_batch(tag, object_[:, None, ...],
                                           self.trainer.iteration_count)
        elif tu.is_image_or_volume_tensor(object_) and allow_image_logging:
            # Log images
            self.log_image_or_volume_batch(tag, object_, self.trainer.iteration_count)
        else:
            # Object is neither a scalar nor an image, there's nothing we can do
            pass

    def end_of_training_iteration(self, **_):
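        """Callback invoked at the end of every training iteration. Logs the observed training
        states, provided the scalar and/or image logging frequencies match."""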
        log_scalars_now = self.log_scalars_now
        log_images_now = self.log_images_now
        if not log_scalars_now and not log_images_now:
            # Nothing to log, so we won't bother
            return
        # Read states
        for state_key in self._trainer_states_being_observed_while_training:
            state = self.trainer.get_state(state_key, default=None)
            if state is None:
                # State not found in trainer but don't throw a hissy fit
                continue
            self.log_object(state_key, state,
                            allow_scalar_logging=log_scalars_now,
                            allow_image_logging=log_images_now)

    def end_of_validation_run(self, **_):
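        """Callback invoked at the end of every validation run. Logs all observed validation
        states, independently of the logging frequencies."""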
        # Log everything
        # Read states
        for state_key in self._trainer_states_being_observed_while_validating:
            state = self.trainer.get_state(state_key, default=None)
            if state is None:
                # State not found in trainer but don't throw a hissy fit
                continue
            self.log_object(state_key, state,
                            allow_scalar_logging=True,
                            allow_image_logging=True)

    def log_image_or_volume_batch(self, tag, batch, step=None):
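        """Log a batch of image or volume tensors. Images are extracted from the batch
        (according to the batch, channel and z-slice indices configured in the constructor)
        and written to Tensorboard under the tags '<tag>/0', '<tag>/1', etc."""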
        assert pyu.is_maybe_list_of(tu.is_image_or_volume_tensor)(batch)
        step = step or self.trainer.iteration_count
        image_list = self.extract_images_from_batch(batch)
        self.log_images(tag, image_list, step)

    def log_scalar(self, tag, value, step):
        """
        Parameters
        ----------
        tag : basestring
            Name of the scalar.
        value : float
            Value of the scalar.
        step : int
            Training iteration.
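
        Examples
        --------
        Illustrative only (assumes `logger` is an already built `TensorboardLogger`):

        >>> logger.log_scalar('training_loss', 0.1, step=100)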
"""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                     simple_value=value)])
        self.writer.add_summary(summary, step)

    def log_images(self, tag, images, step):
        """Logs a list of images."""
        image_summaries = []
        for image_num, image in enumerate(images):
            # Write the image to a string
            try:
                # Python 2.7
                s = StringIO()
                toimage(image).save(s, format="png")
            except TypeError:
                # Python 3.X
                s = BytesIO()
                toimage(image).save(s, format="png")
            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=image.shape[0],
                                       width=image.shape[1])
            # Create a Summary value
            image_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, image_num),
                                                    image=img_sum))
        # Create and write Summary
        summary = tf.Summary(value=image_summaries)
        self.writer.add_summary(summary, step)

    def log_histogram(self, tag, values, step, bins=1000):
        """Logs the histogram of a list/vector of values."""
        # Create histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)
        # Fill fields of histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))
        # The proto requires as many bucket limits as bins, where the first bucket runs from
        # -DBL_MAX to bin_edges[1].
        # See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto#L30
        # Thus, we drop the start of the first bin
        bin_edges = bin_edges[1:]
        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)
        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def get_config(self):
        # Apparently, some SwigPyObject objects cannot be pickled - so we need to build the
        # writer on the fly.
        config = super(TensorboardLogger, self).get_config()
        config.update({'_writer': None})
        return config