# coding: utf-8
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
# Distributed under the terms of "New BSD License", see the LICENSE file.
from pyiron_contrib.protocol.utils.misc import LoggerMixin, ensure_iterable, Registry
from pyiron_atomistics.atomistics.structure.atoms import Atoms
import numpy as np
import logging
"""
Comparer classes that override equality comparison for specific object types.
"""
# Module metadata.
__author__ = "Dominik Gehringer, Liam Huber"
__maintainer__ = "Liam Huber"
__email__ = "huber@mpie.de"
__copyright__ = (
    "Copyright 2019, Max-Planck-Institut für Eisenforschung GmbH "
    "- Computational Materials Design (CM) Department"
)
__version__ = "0.0"
__status__ = "development"
__date__ = "December 10, 2019"
# Prefer the fast xxhash digest; fall back to SHA1 from the standard library when xxhash is not installed.
try:
    from xxhash import xxh64_hexdigest
    hashfunction = xxh64_hexdigest
except ImportError:
    from hashlib import sha1
    logging.getLogger('pyiron_contrib.protocol.generic').debug('Falling back to SHA1 hashing')

    def hashfunction(data):
        """
        Return the SHA1 hex digest of `data` as a string.

        The previous fallback bound `sha1` itself, which returns a hash object rather than a hex string. Both
        branches now return a hex-digest string.

        Args:
            data (bytes or str): The data to hash. `str` input is UTF-8 encoded first, because `hashlib.sha1`
                only accepts bytes-like objects.

        Returns:
            (str): The hexadecimal SHA1 digest.
        """
        if isinstance(data, str):
            data = data.encode('utf-8')
        return sha1(data).hexdigest()
def ensure_iterable_tuple(o):
    """
    Convert `o` to a tuple after passing it through `ensure_iterable`.

    The previous lambda called `tuple(ensure_iterable)`, i.e. it tried to iterate the helper function itself and
    never used `o`, so every call raised `TypeError`.

    Args:
        o: The object to convert. Presumably `ensure_iterable` wraps non-iterable input into an iterable; confirm
            in `pyiron_contrib.protocol.utils.misc`.

    Returns:
        (tuple): The elements produced by `ensure_iterable(o)`.
    """
    return tuple(ensure_iterable(o))
class Comparer(LoggerMixin, metaclass=Registry):
    """
    Class is aware of its subclasses. Subclasses must have a `type` attribute of type `type`.
    Compares two objects, where the behaviour can be overridden, and automatically determines the type.

    Usage: `Comparer(a) == b`. The subclass whose `type` matches `type(a)` (exactly, or failing that as a base
    class) performs the comparison via its `equals` method. If no subclass matches, plain `==` is used
    (see `default`).
    """

    def __init__(self, obj):
        super(Comparer, self).__init__()
        self._object = obj
        self._cls = type(obj)
        # Private cache mapping a subclass's `type` attribute to that subclass; one ought not access it.
        # (The former `isinstance(obj, type(obj))` check was always true and has been removed.)
        self.__registry_cache = {}

    @property
    def object(self):
        """The wrapped object on the left-hand side of the comparison."""
        return self._object

    def compatible_types(self, a, b):
        """
        Tests whether an equality comparison can be made between objects of type `a` and `b` by checking that these are
        the same type, or both either int or float.

        Args:
            a (type): The first type to be tested.
            b (type): The second type to be tested.

        Returns:
            (bool): True when the two types are the same or both belong to int and float.
        """
        return a == b or self.both_are_int_or_float(a, b)

    @staticmethod
    def both_are_int_or_float(a, b):
        """
        Check whether both types are exactly `int` or `float`. Subclasses such as `bool` do not count.

        Args:
            a (type): The first type.
            b (type): The second type.

        Returns:
            (bool): True if each of `a` and `b` is `int` or `float`.
        """
        valid_types = (int, float)
        return a in valid_types and b in valid_types

    def _equals(self, b):
        """
        Compare the wrapped object against `b`, which may be a raw object or another Comparer.

        Returns False when the types are incompatible. Otherwise the comparison is delegated to the registered
        subclass for the wrapped object's type, or to `default` if none exists.
        """
        if isinstance(b, Comparer):
            if not self.compatible_types(b._cls, self._cls):
                return False
            b = b._object
        elif not self.compatible_types(type(b), self._cls):
            # Report the actual class names; `type(self._cls)` would always have been `type`.
            self.logger.warning("Comparer failed due to type difference between {} and {}".format(
                type(b).__name__, self._cls.__name__
            ))
            return False
        comparer = self._get_comparer()
        return self.default(b) if comparer is None else comparer(self.object).equals(b)

    def default(self, b):
        """Fallback comparison: plain `==` between the wrapped object and `b`."""
        return self._object == b

    def _get_comparer(self):
        """
        Find the registered subclass responsible for the wrapped object's type.

        Returns:
            (type or None): The subclass whose `type` equals the wrapped object's class. Failing that, the first one
                whose `type` is a base class of it. None if nothing matches.

        Raises:
            TypeError: If a registered subclass lacks a `type` attribute.
        """
        # Subclasses can be created on the fly, so refresh the cache whenever its size disagrees with the registry.
        if len(self.__registry_cache) != len(self.registry):
            for cls in self.registry:
                if not hasattr(cls, 'type'):
                    raise TypeError('The subclass "{}" must have a "type" attribute'.format(cls.__name__))
                self.__registry_cache[cls.type] = cls
        # Resolve an exact type match directly.
        if self._cls in self.__registry_cache:
            return self.__registry_cache[self._cls]
        # Otherwise the wrapped object's class may subclass one of the registered types.
        for registered_type, comparer_cls in self.__registry_cache.items():
            if issubclass(self._cls, registered_type):
                return comparer_cls
        return None

    def equals(self, b):
        """Type-specific comparison hook; subclasses override this. The base implementation uses `default`."""
        return self.default(b)

    def __eq__(self, other):
        return self._equals(other)
class NumpyArrayComparer(Comparer):
    """
    Used to compare numpy arrays.

    If either operand has an inexact dtype, the arrays must have the same shape and agree element-wise within an
    absolute tolerance of 10 machine epsilons. The coarser epsilon of the two operands is used. Otherwise they must
    be exactly equal.
    """
    type = np.ndarray

    @staticmethod
    def get_machine_epsilon(a):
        """
        Returns the machine inaccuracy for the datatype of `a`.

        Args:
            a (np.ndarray): The array.

        Returns:
            (float or None): The machine epsilon, or None if it is an exact datatype.
        """
        try:
            return np.finfo(a.dtype).eps
        except ValueError:
            # np.finfo rejects dtypes that are not inexact (e.g. int), which therefore get no epsilon.
            return None

    def equals(self, b):
        """
        Compare the wrapped array with `b`.

        The tolerance takes both operands' dtypes into account, so the result no longer depends on which array is
        wrapped. Previously only `b`'s dtype was used.

        Args:
            b (np.ndarray): The array to compare against.

        Returns:
            (bool): Whether the arrays are considered equal.
        """
        fudge_factor = 10
        epsilons = [
            eps for eps in (self.get_machine_epsilon(self.object), self.get_machine_epsilon(b)) if eps is not None
        ]
        if not epsilons:
            # Both datatypes are exact (e.g. int), so demand exact equality.
            return np.array_equal(self.object, b)
        # The larger epsilon belongs to the less precise operand, so that precision bounds the comparison.
        epsilon = max(epsilons)
        # allclose broadcasts, so the shapes must be checked explicitly.
        return self.object.shape == b.shape and np.allclose(self.object, b, atol=fudge_factor * epsilon, rtol=0)
class AtomsComparer(Comparer):
    """
    Used to compare pyiron Atoms objects.

    Two structures are equal when they agree in number of atoms, cell, scaled positions, initial magnetic moments,
    and the chemical symbol at every site index.
    """
    type = Atoms

    def equals(self, b):
        assert isinstance(b, Atoms)
        assert isinstance(self.object, Atoms)
        mine = self.object

        def symbols_by_index(structure):
            # Map each site index to its chemical symbol.
            return {site.index: site.symbol for site in structure}

        # Structure comparison modeled on
        # https://github.com/pyiron/pyiron/blob/c447ffb4f1e003d0ebaced50a12def46beefab4f/pyiron/atomistics/job/interactive.py
        # Every check is evaluated eagerly, in this order, before the results are combined.
        checks = (
            len(mine) == len(b),
            Comparer(mine.cell.array) == b.cell.array,
            Comparer(mine.get_scaled_positions()) == b.get_scaled_positions(),
            Comparer(mine.get_initial_magnetic_moments()) == b.get_initial_magnetic_moments(),
            symbols_by_index(mine) == symbols_by_index(b),
        )
        return all(checks)
class ListComparer(Comparer):
    """
    Used to compare lists.

    Lists are equal when they have the same length and each pair of corresponding elements compares equal through
    `Comparer`, so type-specific comparison applies recursively to the elements.
    """
    type = list

    def equals(self, b):
        assert isinstance(b, list)
        assert isinstance(self.object, list)
        same_length = len(self.object) == len(b)
        # Evaluate every element comparison up front (zip stops at the shorter list).
        element_matches = [Comparer(mine) == theirs for mine, theirs in zip(self.object, b)]
        return same_length and all(element_matches)