aboutsummaryrefslogtreecommitdiffstats
path: root/tests/_unittest_compat.py
blob: 2d4985dd674d80926474d185f5d4131a8a4a6d39 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# coding: utf-8
from __future__ import unicode_literals, division, absolute_import, print_function

import sys
import unittest
import re


# Mutable module-level state recording whether patch() has already run.
# A dict is used so patch() can update it without a ``global`` statement.
_non_local = {'patched': False}


def patch():
    """Install assertion backports on unittest.TestCase for Python < 2.7.

    Adds ``assertIsInstance``, ``assertRaises`` (usable as a context manager)
    and ``assertRaisesRegexp``.  On Python 2.7 and later this is a no-op.
    Repeated calls are harmless: the patch is applied at most once.
    """
    needs_patch = sys.version_info < (2, 7) and not _non_local['patched']
    if needs_patch:
        test_case_cls = unittest.TestCase
        test_case_cls.assertIsInstance = _assert_is_instance
        test_case_cls.assertRaises = _assert_raises
        test_case_cls.assertRaisesRegexp = _assert_raises_regexp
        _non_local['patched'] = True


def _assert_is_instance(self, obj, cls, msg=None):
    """Fail the test unless ``obj`` is an instance of ``cls``.

    Equivalent to ``self.assertTrue(isinstance(obj, cls))``, except that when
    ``msg`` is missing or empty the failure message names the offending
    object and the expected class.
    """
    if isinstance(obj, cls):
        return
    self.fail(msg or '%s is not an instance of %r' % (obj, cls))


def _assert_raises(self, expected_exception, callableObj=None, *args, **kwargs):  # noqa
    """Backport of ``TestCase.assertRaises`` with context-manager support.

    When ``callableObj`` is omitted, the context manager is returned so the
    caller can write ``with self.assertRaises(Exc): ...``.  Otherwise
    ``callableObj(*args, **kwargs)`` is run inside that context and ``None``
    is returned.
    """
    ctx = _AssertRaisesContext(expected_exception, self)
    if callableObj is not None:
        with ctx:
            callableObj(*args, **kwargs)
        return None
    return ctx


def _assert_raises_regexp(self, expected_exception, expected_regexp, callable_obj=None, *args, **kwargs):
    """Backport of ``TestCase.assertRaisesRegexp`` for Python < 2.7.

    Like ``_assert_raises``, but additionally requires ``str()`` of the raised
    exception to contain a match for ``expected_regexp``.  The pattern may be
    a string or an already-compiled regex; ``None`` skips the message check.
    Returns the context manager when ``callable_obj`` is omitted, else runs
    ``callable_obj(*args, **kwargs)`` inside it and returns ``None``.
    """
    pattern = None if expected_regexp is None else re.compile(expected_regexp)
    ctx = _AssertRaisesContext(expected_exception, self, pattern)
    if callable_obj is not None:
        with ctx:
            callable_obj(*args, **kwargs)
        return None
    return ctx


class _AssertRaisesContext(object):
    """A context manager used to implement TestCase.assertRaises* methods.

    Follows the Python 2.7 ``unittest`` implementation.  In addition, it
    normalizes the exception value it is handed, so ``self.exception`` is
    always an exception instance on older interpreters too.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        # expected: exception class (or tuple of classes) accepted by issubclass.
        # test_case: only its ``failureException`` attribute is used.
        # expected_regexp: compiled pattern searched for in str(exception),
        # or None to skip the message check.
        self.expected = expected
        self.failureException = test_case.failureException
        self.expected_regexp = expected_regexp

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Check the exception (if any) raised inside the ``with`` body.

        Raises ``failureException`` if nothing was raised or if the message
        does not match ``expected_regexp``.  Returns False to propagate
        unrelated exceptions, True to suppress an expected one.
        """
        if exc_type is None:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                # e.g. a tuple of exception classes has no __name__
                exc_name = str(self.expected)
            raise self.failureException(
                "{0} not raised".format(exc_name))
        if not issubclass(exc_type, self.expected):
            # let unexpected exceptions pass through
            return False
        exc_value = self._normalize(exc_type, exc_value)
        self.exception = exc_value  # store for later retrieval
        if self.expected_regexp is None:
            return True

        expected_regexp = self.expected_regexp
        exc_text = str(exc_value)
        if not expected_regexp.search(exc_text):
            raise self.failureException(
                '"%s" does not match "%s"' %
                (expected_regexp.pattern, exc_text)
            )
        return True

    @staticmethod
    def _normalize(exc_type, exc_value):
        """Return ``exc_value`` as an instance of ``exc_type``.

        Interpreters older than 2.7 (the only ones this shim is installed on)
        may call ``__exit__`` with the raw, unnormalized exception value:
        ``None``, a single argument such as a message string, or a tuple of
        arguments.  Build the instance the same way ``raise exc_type, value``
        would.  This is best-effort: if the exception class cannot be
        constructed from those arguments, the original value is returned
        unchanged rather than masking the exception under test.
        """
        if isinstance(exc_value, exc_type):
            return exc_value
        try:
            if exc_value is None:
                return exc_type()
            if isinstance(exc_value, tuple):
                return exc_type(*exc_value)
            return exc_type(exc_value)
        except Exception:
            return exc_value