import numpy
from chainer import cuda
from chainer import serializer
class DictionarySerializer(serializer.Serializer):

    """Serializer that collects arrays into a flat dictionary.

    This is the standard serializer in Chainer. The hierarchy of objects is
    flattened into a single dictionary whose keys are the ``/``-separated
    paths to the objects in that hierarchy.

    .. note::
       Despite its name, this serializer DOES NOT write the object to an
       external file. It only builds a flat dictionary of arrays that can be
       passed to :func:`numpy.savez` or :func:`numpy.savez_compressed`. If
       you use this serializer directly, you have to hand the resulting
       dictionary to one of these functions yourself.

    Args:
        target (dict): The dictionary that receives the serialized arrays.
            If it is ``None``, a new dictionary is created.
        path (str): The base path in the hierarchy that this serializer
            points to.

    Attributes:
        target (dict): The target dictionary. Once serialization has
            finished, it can be passed to :func:`numpy.savez` or
            :func:`numpy.savez_compressed` to store it in the NPZ format.

    """

    def __init__(self, target=None, path=''):
        if target is None:
            target = {}
        self.target = target
        self.path = path

    def __getitem__(self, key):
        """Returns a child serializer rooted at ``path + key + '/'``.

        The child writes into the same ``target`` dictionary; only the key
        prefix differs.
        """
        child_path = self.path + key.strip('/') + '/'
        return DictionarySerializer(self.target, child_path)

    def __call__(self, key, value):
        """Stores ``value`` as an array under ``path + key``.

        Args:
            key (str): Name of the entry; leading slashes are removed.
            value: Value to store. It is converted with
                :func:`numpy.asarray` before being put into ``target``.

        Returns:
            The ``value`` argument exactly as it was passed in.
        """
        name = key.lstrip('/')
        # ``cuda.ndarray`` values go through their ``get()`` method first
        # (presumably a device-to-host copy -- confirm against ``cuda``).
        host_value = value.get() if isinstance(value, cuda.ndarray) else value
        self.target[self.path + name] = numpy.asarray(host_value)
        return value
def save_npz(filename, obj, compression=True):
    """Serializes one object and writes it to a file in NPZ format.

    This is a shortcut for the common case of storing exactly one object in
    an NPZ file.

    Args:
        filename (str): Name of the file to write.
        obj: Object to be serialized. It must support the serialization
            protocol.
        compression (bool): If ``True``, the resulting zip file is written
            with compression enabled.

    """
    collector = DictionarySerializer()
    # Serialize before opening the file, so that a failure here does not
    # create or truncate the output file.
    collector.save(obj)
    write_npz = numpy.savez_compressed if compression else numpy.savez
    with open(filename, 'wb') as out:
        write_npz(out, **collector.target)
class NpzDeserializer(serializer.Deserializer):

    """Deserializer for NPZ format.

    This is the standard deserializer in Chainer. This deserializer can be
    used to read an object serialized by :func:`save_npz`.

    Args:
        npz: `npz` file object.
        path: The base path that the deserialization starts from.
        strict (bool): If ``True``, the deserializer raises an error when an
            expected value is not found in the given NPZ file. Otherwise,
            it ignores the value and skips deserialization.

    """

    def __init__(self, npz, path='', strict=True):
        self.npz = npz
        self.path = path
        self.strict = strict

    def __getitem__(self, key):
        """Returns a child deserializer rooted at ``path + key + '/'``.

        The child reads from the same ``npz`` object and inherits the
        ``strict`` setting.
        """
        key = key.strip('/')
        return NpzDeserializer(
            self.npz, self.path + key + '/', strict=self.strict)

    def __call__(self, key, value):
        """Loads the entry ``path + key`` into ``value``.

        Args:
            key (str): Name of the entry; leading slashes are removed.
            value: Destination of the loaded data. ``numpy.ndarray`` and
                ``cuda.ndarray`` objects are overwritten in place. For
                ``None`` the stored data itself is returned. Any other value
                is replaced by a new object of the same type built from the
                stored data.

        Returns:
            The loaded value, or ``value`` unchanged when ``strict`` is
            ``False`` and the entry does not exist.

        Raises:
            KeyError: Propagated from ``npz[key]`` when the entry is missing
                in strict mode (this is what NumPy's ``NpzFile`` raises; other
                ``npz`` objects may raise a different lookup error).
        """
        key = self.path + key.lstrip('/')
        if not self.strict and key not in self.npz:
            return value

        dataset = self.npz[key]
        if value is None:
            return dataset
        elif isinstance(value, numpy.ndarray):
            numpy.copyto(value, dataset)
        elif isinstance(value, cuda.ndarray):
            # Convert to the destination dtype before the transfer so this
            # branch agrees with the ``numpy.copyto`` branch above, which
            # also casts to the destination dtype. Without the conversion a
            # snapshot stored with a different dtype is handed to ``set``
            # as-is (NOTE(review): CuPy's ``set`` is understood to reject a
            # dtype mismatch -- confirm). This is a no-op when dtypes match.
            value.set(numpy.asarray(dataset, dtype=value.dtype))
        else:
            value = type(value)(numpy.asarray(dataset))
        return value
def load_npz(filename, obj):
    """Restores an object from a file in NPZ format.

    This is a shortcut for loading from an `.npz` file that contains exactly
    one object.

    Args:
        filename (str): Name of the file to read.
        obj: Object to be deserialized. It must support the serialization
            protocol.

    """
    # The context manager closes the archive once loading has finished.
    with numpy.load(filename) as npz:
        NpzDeserializer(npz).load(obj)