#!/usr/bin/env python
"""Extension of the built-in unittest framework for floating-point comparisons.

Specific Extensions:

assertFloatEqual, assertFloatEqualAbs, and assertFloatEqualRel give fine-
grained control over how floating point numbers (or lists thereof) are tested
for equality.

assertContains and assertNotContains give more helpful error
messages when testing whether an observed item is present or absent in a set
of possibilities. Ditto assertGreaterThan, assertLessThan, assertIsBetween,
and assertIsProb (which is a special case of assertIsBetween requiring the
result to be between 0 and 1).

assertSameItems and assertEqualItems test the items in a list
for pairwise identity and equality respectively (i.e. the observed and
expected values must have the same number of each item, though the order can
differ); assertNotEqualItems verifies that two lists do not contain equal sets
of items.

assertSimilarMeans and assertSimilarFreqs allow you to test stochastic results
by setting an explicit P-value and checking that the result is not improbable
given the expected P-value. Please use these instead of guessing confidence
intervals! The major advantage is that you can reset the P-value globally over
the whole test suite, so that rare failures don't occur every time.

assertIsPermutation checks that you get a permutation of an expected result
that differs from the original result.
"""
#from contextlib import contextmanager
import numpy
from numpy import testing, array, asarray, ravel, zeros, \
    logical_and, logical_or, isfinite
from unittest import main, TestCase as orig_TestCase, TestSuite, findTestCases
from cogent.util.misc import recursive_flatten
from cogent.maths.stats.test import t_two_sample, G_ind

__author__ = "Rob Knight"
__copyright__ = "Copyright 2007-2012, The Cogent Project"
__credits__ = ["Rob Knight", "Peter Maxwell", "Sandra Smit", "Zongzhi Liu",
               "Micah Hamady", "Daniel McDonald"]
__license__ = "GPL"
__version__ = "1.5.3-dev"
__maintainer__ = "Rob Knight"
__email__ = "rob@spot.colorado.edu"
__status__ = "Production"

## SUPPORT2425
#@contextmanager
#def numpy_err(**kw):
#    """a numpy err context manager.
#    **kw: pass to numpy.seterr(all=None, divide=None, over=None, under=None,
#        invalid=None)
#    Example:
#        with numpy_err(divide='raise'):
#            self.assertRaises(FloatingPointError, log, 0)
#    """
#    ori_err = numpy.geterr()
#    numpy.seterr(**kw)
#    try: yield None
#    finally: numpy.seterr(**ori_err)


class FakeRandom(object):
    """Drop-in substitute for random.random that provides items from list."""

    def __init__(self, data, circular=False):
        """Returns new FakeRandom object, using list of items in data.

        circular: if True (default is False), wraps the list around. Otherwise,
        raises IndexError when we run off the end of the list.

        WARNING: data must always be iterable, even if it's a single item.
        """
        self._data = data
        # Starts one before the first item; __call__ advances before reading.
        self._ptr = -1
        self._circular = circular

    def __call__(self, *args, **kwargs):
        """Returns next item from the list in self._data.

        All positional and keyword arguments are accepted and ignored.

        Raises IndexError when we run out of data (non-circular mode only).
        """
        self._ptr += 1
        #wrap around if circular
        if self._circular and self._ptr >= len(self._data):
            self._ptr = 0
        return self._data[self._ptr]


class TestCase(orig_TestCase):
    """Adds some additional utility methods to unittest.TestCase.

    Notably, adds facilities for dealing with floating point numbers,
    and some common templates for replicated tests.

    BEWARE: Do not start any method with 'test' unless you want it to actually
    run as a test suite in every instance!
    """
    _suite_pvalue = None # see TestCase._set_suite_pvalue()

    def _get_values_from_matching_dicts(self, d1, d2):
        """Gets corresponding values from matching dicts.

        Returns None when the two dicts do not have the same set of keys;
        otherwise returns a pair of lists (values of d1, values of d2), both
        ordered by d1's key iteration order.
        """
        if set(d1) != set(d2):
            return None
        # Index both dicts with d1's keys: the two dicts might not iterate
        # in the same order. Lists (not views) so numpy can build arrays.
        keys = list(d1)
        return [d1[k] for k in keys], [d2[k] for k in keys]

    def errorCheck(self, call, known_errors):
        """Applies function to (data, error) tuples, checking for error.

        For each pair, asserts that call(data) raises error.
        """
        for (data, error) in known_errors:
            self.assertRaises(error, call, data)

    def valueCheck(self, call, known_values, arg_prefix='', eps=None):
        """Applies function to (data, expected) tuples, treating data as args.

        arg_prefix: '' passes data as a single argument, '*' unpacks it as
        positional arguments, '**' unpacks it as keyword arguments.
        eps: if convertible to float, results are compared with
        assertFloatEqual using eps as the tolerance; if float(eps) raises
        TypeError (e.g. eps is None), results are compared with assertEqual.
        """
        try:
            allowed_diff = float(eps)
        except TypeError:
            allowed_diff = None

        for (data, expected) in known_values:
            if arg_prefix == '':
                observed = call(data)
            elif arg_prefix == '*':
                observed = call(*data)
            elif arg_prefix == '**':
                observed = call(**data)
            else:
                # NOTE(review): kept only for backward compatibility with
                # unusual prefixes; eval must never see an untrusted
                # arg_prefix.
                observed = eval('call(' + arg_prefix + 'data)')

            if allowed_diff is None:
                self.assertEqual(observed, expected)
            else:
                self.assertFloatEqual(observed, expected, allowed_diff)

    def assertFloatEqualRel(self, obs, exp, eps=1e-6):
        """Tests whether two floating point numbers/arrays are approx. equal.

        Checks whether the distance is within epsilon relative to the value
        of the sum of observed and expected. Use this method when you expect
        the difference to be small relative to the magnitudes of the observed
        and expected values. When the sum of a pair is zero, the absolute
        difference is compared against epsilon instead.

        obs, exp may be scalars, sequences/arrays of the same shape, or dicts
        with the same keys (values are then compared key by key).

        Note: for arbitrary objects, need to compare the specific attribute
        that's numeric, not the whole object, using this method.
        """
        #do array check first
        #note that we can't use array ops to combine, because we need to check
        #at each element whether the expected is zero to do the test to avoid
        #floating point error.
        #WARNING: numpy iterates over objects that are not regular Python
        #floats/ints, so need to explicitly catch scalar values and prevent
        #cast to array if we want the exact object to print out correctly.
        is_array = False
        if hasattr(obs, 'keys') and hasattr(exp, 'keys'):   #both dicts?
            result = self._get_values_from_matching_dicts(obs, exp)
            if result is None:
                # Different key sets: zipping the keys below could silently
                # pass without ever comparing a value, so fail explicitly.
                self.fail("Got %s, but expected %s" % (repr(obs), repr(exp)))
            obs, exp = result
        else:
            try:
                iter(obs)
                iter(exp)
            except TypeError:
                # Scalars: wrap so the pairwise loop below can handle them.
                obs = [obs]
                exp = [exp]
            else:
                try:
                    arr_obs = array(obs)
                    arr_exp = array(exp)
                    arr_diff = arr_obs - arr_exp
                    if arr_obs.shape != arr_exp.shape:
                        self.fail("Wrong shape: Got %s, but expected %s" %
                                  (repr(obs), repr(exp)))
                    obs = arr_obs.ravel()
                    exp = arr_exp.ravel()
                    is_array = True
                except (TypeError, ValueError):
                    # Not usable as arrays; fall back to pairwise comparison.
                    pass

        # shape mismatch can still get by...
        # explicit cast is to work around bug in certain versions of numpy
        # installed version on osx 10.5
        if asarray(obs, object).shape != asarray(exp, object).shape:
            self.fail("Wrong shape: Got %s, but expected %s" % (obs, exp))

        for observed, expected in zip(obs, exp):
            #try the cheap comparison first
            if observed == expected:
                continue
            try:
                total = float(observed + expected)
                diff = float(observed - expected)
                # A zero sum cannot be used as a denominator, so use the
                # absolute difference in that case.
                if total == 0:
                    error = abs(diff)
                else:
                    error = abs(diff / total)
                # For array input, report the whole arrays for context.
                if is_array:
                    details = (repr(arr_obs), repr(arr_exp), repr(arr_diff))
                else:
                    details = (repr(observed), repr(expected), repr(diff))
                self.assertFalse(error > abs(eps),
                    "Got %s, but expected %s (diff was %s)" % details)
            except (TypeError, ValueError, AttributeError,
                    NotImplementedError):
                self.fail("Got %s, but expected %s" %
                          (repr(observed), repr(expected)))

    def assertFloatEqualAbs(self, obs, exp, eps=1e-6):
        """Tests whether two floating point numbers are approximately equal.

        Checks whether the absolute value of (a - b) is within epsilon. Use
        this method when you expect that one of the values should be very
        small, and the other should be zero.

        obs, exp may be scalars, sequences/arrays of the same shape, or dicts
        with the same keys (values are then compared key by key).
        """
        #do array check first
        #note that we can't use array ops to combine, because we need to check
        #at each element whether the expected is zero to do the test to avoid
        #floating point error.
        if hasattr(obs, 'keys') and hasattr(exp, 'keys'):   #both dicts?
            result = self._get_values_from_matching_dicts(obs, exp)
            if result is None:
                # Different key sets: zipping the keys below could silently
                # pass without ever comparing a value, so fail explicitly.
                self.fail("Got %s, but expected %s" % (repr(obs), repr(exp)))
            obs, exp = result
        else:
            try:
                iter(obs)
                iter(exp)
            except TypeError:
                # Scalars: wrap so the pairwise loop below can handle them.
                obs = [obs]
                exp = [exp]
            else:
                try:
                    arr_obs = array(obs)
                    arr_exp = array(exp)
                    if arr_obs.shape != arr_exp.shape:
                        self.fail("Wrong shape: Got %s, but expected %s" %
                                  (repr(obs), repr(exp)))
                    diff = arr_obs - arr_exp
                    # abs(eps) keeps this consistent with the pairwise path.
                    self.assertFalse(abs(diff).max() > abs(eps),
                        "Got %s, but expected %s (diff was %s)" %
                        (repr(obs), repr(exp), repr(diff)))
                    return
                except (TypeError, ValueError):
                    pass
        #only get here if array comparison failed
        for observed, expected in zip(obs, exp):
            #cheap comparison first
            if observed == expected:
                continue
            try:
                diff = observed - expected
                self.assertFalse(abs(diff) > abs(eps),
                    "Got %s, but expected %s (diff was %s)" %
                    (repr(observed), repr(expected), repr(diff)))
            except (TypeError, ValueError, AttributeError,
                    NotImplementedError):
                self.fail("Got %s, but expected %s" %
                          (repr(observed), repr(expected)))

    def assertFloatEqual(self, obs, exp, eps=1e-6, rel_eps=None, \
        abs_eps=None):
        """Tests whether two floating point numbers are approximately equal.

        If one of the arguments is zero, tests the absolute magnitude of the
        difference; otherwise, tests the relative magnitude.

        eps is the default tolerance; rel_eps and abs_eps override it for the
        relative and absolute comparisons respectively when they are truthy.

        Use this method as a reasonable default.
        """
        obs = numpy.ravel(numpy.asarray(obs, dtype='O'))
        exp = numpy.ravel(numpy.asarray(exp, dtype='O'))

        if obs.shape != exp.shape:
            self.fail("Shape mismatch. Got, %s but expected %s" % (obs, exp))

        # Tolerances do not depend on the element, so resolve them once.
        rel_eps = rel_eps or eps
        abs_eps = abs_eps or eps

        for observed, expected in zip(obs, exp):
            if self._is_equal(observed, expected):
                continue
            try:
                if (observed == 0) or (expected == 0):
                    self.assertFloatEqualAbs(observed, expected, abs_eps)
                else:
                    self.assertFloatEqualRel(observed, expected, rel_eps)
            except (TypeError, ValueError, AttributeError,
                    NotImplementedError):
                self.fail("Got %s, but expected %s" %
                          (repr(observed), repr(expected)))

    def _is_equal(self, observed, expected):
        """Returns True if observed and expected are equal, False otherwise.

        Objects offering tolist() are converted with it before the ==
        comparison; anything else is compared as-is.
        """
        #errors to catch: TypeError when obs is None
        tolist_errors = (AttributeError, ValueError, TypeError)

        try:
            obs = observed.tolist()
        except tolist_errors:
            obs = observed

        try:
            exp = expected.tolist()
        except tolist_errors:
            exp = expected

        return obs == exp

    def failUnlessEqual(self, observed, expected, msg=None):
        """Fail if the two objects are unequal as determined by !=

        Overridden to make error message enforce order of observed, expected.
        Use numpy.testing.assert_equal if ValueError, TypeError raised.
        """
        try:
            if not self._is_equal(observed, expected):
                raise self.failureException(
                    msg or 'Got %s, but expected %s' %
                    (repr(observed), repr(expected)))
        except (ValueError, TypeError):
            #The truth value of an array with more than one element is
            #ambiguous. Use a.any() or a.all()
            #descriptor 'tolist' of 'numpy.generic' object needs an argument
            testing.assert_equal(observed, expected)

    def failIfEqual(self, observed, expected, msg=None):
        """Fail if the two objects are equal as determined by =="""
        try:
            self.failUnlessEqual(observed, expected)
        except self.failureException:
            pass
        else:
            raise self.failureException(
                msg or 'Observed %s and expected %s: shouldn\'t test equal'
                % (repr(observed), repr(expected)))

    #following needed to get our version instead of unittest's
    assertEqual = assertEquals = failUnlessEqual

    assertNotEqual = assertNotEquals = failIfEqual

    def assertEqualItems(self, observed, expected, msg=None):
        """Fail if the two items contain unequal elements.

        Both inputs are copied to lists and sorted, then compared pairwise,
        so the order of the items in the inputs does not matter.
        """
        obs_items = list(observed)
        exp_items = list(expected)
        if len(obs_items) != len(exp_items):
            raise self.failureException(
                msg or 'Observed and expected are different lengths: %s and %s'
                % (len(obs_items), len(exp_items)))

        obs_items.sort()
        exp_items.sort()
        for index, (obs, exp) in enumerate(zip(obs_items, exp_items)):
            if obs != exp:
                raise self.failureException(
                    msg or 'Observed %s and expected %s at sorted index %s'
                    % (obs, exp, index))

    def assertSameItems(self, observed, expected, msg=None):
        """Fail if the two items contain non-identical elements.

        Elements are matched by id(), so the two inputs must hold the very
        same objects, in any order.
        """
        obs_items = list(observed)
        exp_items = list(expected)
        if len(obs_items) != len(exp_items):
            raise self.failureException(
                msg or 'Observed and expected are different lengths: %s and %s'
                % (len(obs_items), len(exp_items)))

        # Sort on the id only, so the items themselves never get compared.
        by_id = lambda pair: pair[0]
        obs_ids = sorted([(id(i), i) for i in obs_items], key=by_id)
        exp_ids = sorted([(id(i), i) for i in exp_items], key=by_id)
        for index, (obs, exp) in enumerate(zip(obs_ids, exp_ids)):
            o_id, o = obs
            e_id, e = exp
            if o_id != e_id:    #i.e. the ids are different
                raise self.failureException(
                    msg or
                    'Observed %s <%s> and expected %s <%s> at sorted index %s'
                    % (o, o_id, e, e_id, index))

    def assertNotEqualItems(self, observed, expected, msg=None):
        """Fail if the two items contain only equal elements when sorted"""
        try:
            self.assertEqualItems(observed, expected, msg)
        except Exception:
            # Any failure to establish equality (including items that cannot
            # be sorted or compared) counts as "not equal".
            pass
        else:
            raise self.failureException(
                msg or 'Observed %s has same items as %s' %
                (repr(observed), repr(expected)))

    def assertContains(self, observed, item, msg=None):
        """Fail if item not in observed"""
        try:
            if item in observed:
                return
        except (TypeError, ValueError):
            # A failed membership test is treated as "not found".
            pass
        raise self.failureException(
            msg or 'Item %s not found in %s' % (repr(item), repr(observed)))

    def assertNotContains(self, observed, item, msg=None):
        """Fail if item in observed"""
        try:
            if item not in observed:
                return
        except (TypeError, ValueError):
            # A failed membership test is treated as "not present".
            return
        raise self.failureException(
            msg or 'Item %s should not have been in %s' %
            (repr(item), repr(observed)))

    def assertGreaterThan(self, observed, value, msg=None):
        """Fail if observed is <= value.

        observed may be a scalar or array-like; every element must be greater
        than value. None for either argument, or a comparison that raises,
        counts as a failure.
        """
        try:
            if value is None or observed is None:
                raise ValueError
            if (asarray(observed) > value).all():
                return
        except Exception:
            pass
        raise self.failureException(
            msg or 'Observed %s has elements <= %s' %
            (repr(observed), repr(value)))

    def assertLessThan(self, observed, value, msg=None):
        """Fail if observed is >= value.

        observed may be a scalar or array-like; every element must be less
        than value. None for either argument, or a comparison that raises,
        counts as a failure.
        """
        try:
            if value is None or observed is None:
                raise ValueError
            if (asarray(observed) < value).all():
                return
        except Exception:
            pass
        raise self.failureException(
            msg or 'Observed %s has elements >= %s' %
            (repr(observed), repr(value)))

    def assertIsBetween(self, observed, min_value, max_value, msg=None):
        """Fail if observed is not between min_value and max_value.

        The bounds are exclusive. A None argument, min_value >= max_value, or
        a comparison that raises, counts as a failure.
        """
        try:
            if min_value is None or max_value is None or observed is None:
                raise ValueError

            if min_value >= max_value:
                raise ValueError

            if logical_and(asarray(observed) < max_value,
                           asarray(observed) > min_value).all():
                return
        except Exception:
            pass
        raise self.failureException(
            msg or 'Observed %s has elements not between %s, %s' %
            (repr(observed), repr(min_value), repr(max_value)))

    def assertIsNotBetween(self, observed, min_value, max_value, msg=None):
        """Fail if observed is between min_value and max_value.

        Passes only when every element is <= min_value or >= max_value. A
        None argument, min_value >= max_value, or a comparison that raises,
        counts as a failure.
        """
        try:
            if min_value is None or max_value is None or observed is None:
                raise ValueError

            if min_value >= max_value:
                raise ValueError

            if logical_or(asarray(observed) >= max_value,
                          asarray(observed) <= min_value).all():
                return
        except Exception:
            pass
        raise self.failureException(
            msg or 'Observed %s has elements between %s, %s' %
            (repr(observed), repr(min_value), repr(max_value)))

    def assertIsProb(self, observed, msg=None):
        """Fail if observed is not between 0.0 and 1.0 (inclusive)"""
        try:
            if observed is None:
                raise ValueError

            if (asarray(observed) >= 0.0).all() and \
               (asarray(observed) <= 1.0).all():
                return
        except Exception:
            pass
        raise self.failureException(
            msg or 'Observed %s has elements that are not probs' %
            (repr(observed)))

    def _set_suite_pvalue(self, pvalue):
        """Sets the test suite pvalue to be used in similarity tests

        This value is by default None. The pvalue used in this case is
        specified in the test module itself. The purpose of this method is to
        set the pvalue to be used when running a massive test suite
        """
        self._suite_pvalue = pvalue

    def assertSimilarMeans(self, observed, expected, pvalue=0.01, msg=None):
        """Fail if observed p is lower than pvalue.

        p comes from t_two_sample(observed, expected). A truthy suite-wide
        pvalue (see _set_suite_pvalue) overrides the pvalue argument.

        When p is None or not finite, the samples are accepted only if their
        first elements are equal; otherwise the assertion fails.
        """
        if self._suite_pvalue:
            pvalue = self._suite_pvalue

        observed, expected = asarray(observed), asarray(expected)

        t, p = t_two_sample(observed, expected)

        # Guard against None first: None > float raises on Python 3.
        if p is not None and p > pvalue:
            return
        if p is None or not isfinite(p):
            #handle case where all elements were the same
            if not observed.shape:
                observed = observed.reshape((1,))
            if not expected.shape:
                expected = expected.reshape((1,))
            if observed[0] == expected[0]:
                return
        # Reached for a low p, or for an undefined p with differing values.
        raise self.failureException(
            msg or 'p-value %s, t-test p %s' % (repr(pvalue), repr(p)))

    def assertSimilarFreqs(self, observed, expected, pvalue=0.01, msg=None):
        """Fail if observed p is lower than pvalue.

        p comes from G_ind applied to a two-row table holding the flattened
        observed and expected values. A truthy suite-wide pvalue (see
        _set_suite_pvalue) overrides the pvalue argument.
        """
        if self._suite_pvalue:
            pvalue = self._suite_pvalue

        obs_ravel = ravel(asarray(observed))
        exp_ravel = ravel(asarray(expected))

        m = zeros((2, len(obs_ravel)))
        m[0,:] = obs_ravel
        m[1,:] = exp_ravel

        G, p = G_ind(m)

        if p > pvalue:
            return
        raise self.failureException(
            msg or 'p-value %s, G-test p %s' % (repr(pvalue), repr(p)))

    def assertIsPermutation(self, observed, items, msg=None):
        """Fail if observed is not a permutation of items.

        observed must hold equal items to items (in any order) and must not
        compare equal to items itself.
        """
        try:
            self.assertEqualItems(observed, items)
            self.assertNotEqual(observed, items)
            return
        except Exception:
            pass
        raise self.failureException(
            msg or 'Observed %s is not a different permutation of items %s' %
            (repr(observed), repr(items)))

    def assertSameObj(self, observed, expected, msg=None):
        """Fail if 'observed is not expected'"""
        if observed is expected:
            return
        raise self.failureException(
            msg or 'Observed %s is not the same as expected %s' %
            (repr(observed), repr(expected)))

    def assertNotSameObj(self, observed, expected, msg=None):
        """Fail if 'observed is expected'"""
        if observed is not expected:
            return
        raise self.failureException(
            msg or 'Observed %s is the same as expected %s' %
            (repr(observed), repr(expected)))