summaryrefslogtreecommitdiffstats
path: root/third_party/typ/typ/tests/pool_test.py
blob: fdb3f399c0306ec858baf35e75bd59d082e15572 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright 2014 Dirk Pranke. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typ import test_case
from typ.host import Host
from typ.pool import make_pool, _MessageType, _ProcessPool, _loop


def _pre(host, worker_num, context):  # pylint: disable=W0613
    context['pre'] = True
    return context


def _post(context):
    context['post'] = True
    return context


def _echo(context, msg):
    return '%s/%s/%s' % (context['pre'], context['post'], msg)


def _error(context, msg):  # pylint: disable=W0613
    raise Exception('_error() raised Exception')


def _interrupt(context, msg):  # pylint: disable=W0613
    raise KeyboardInterrupt()


def _stub(*args):  # pylint: disable=W0613
    return None


class TestPool(test_case.TestCase):

    def run_basic_test(self, jobs):
        host = Host()
        context = {'pre': False, 'post': False}
        pool = make_pool(host, jobs, _echo, context, _pre, _post)
        pool.send('hello')
        pool.send('world')
        msg1 = pool.get()
        msg2 = pool.get()
        pool.close()
        final_contexts = pool.join()
        self.assertEqual(set([msg1, msg2]),
                         set(['True/False/hello',
                              'True/False/world']))
        expected_context = {'pre': True, 'post': True}
        expected_final_contexts = [expected_context for _ in range(jobs)]
        self.assertEqual(final_contexts, expected_final_contexts)

    def run_through_loop(self, callback=None, pool=None):
        callback = callback or _stub
        if pool:
            host = pool.host
        else:
            host = Host()
            pool = _ProcessPool(host, 0, _stub, None, _stub, _stub)
            pool.send('hello')

        worker_num = 1
        _loop(pool.requests, pool.responses, host, worker_num, callback,
              None, _stub, _stub, should_loop=False)
        return pool

    def test_async_close(self):
        host = Host()
        pool = make_pool(host, 1, _echo, None, _stub, _stub)
        pool.join()

    def test_basic_one_job(self):
        self.run_basic_test(1)

    def test_basic_two_jobs(self):
        self.run_basic_test(2)

    def test_join_discards_messages(self):
        host = Host()
        context = {'pre': False, 'post': False}
        pool = make_pool(host, 2, _echo, context, _pre, _post)
        pool.send('hello')
        pool.close()
        pool.join()
        self.assertEqual(len(pool.discarded_responses), 1)

    def test_join_gets_an_error(self):
        host = Host()
        pool = make_pool(host, 2, _error, None, _stub, _stub)
        pool.send('hello')
        pool.close()
        try:
            pool.join()
        except Exception as e:
            self.assertIn('_error() raised Exception', str(e))

    def test_join_gets_an_interrupt(self):
        host = Host()
        pool = make_pool(host, 2, _interrupt, None, _stub, _stub)
        pool.send('hello')
        pool.close()
        self.assertRaises(KeyboardInterrupt, pool.join)

    def test_loop(self):
        pool = self.run_through_loop()
        resp = pool.get()
        self.assertEqual(resp, None)
        pool.requests.put((_MessageType.Close, None))
        pool.close()
        self.run_through_loop(pool=pool)
        pool.join()

    def test_loop_fails_to_respond(self):
        # This tests what happens if _loop() tries to send a response
        # on a closed queue; we can't simulate this directly through the
        # api in a single thread.
        pool = self.run_through_loop()
        pool.requests.put((_MessageType.Request, None))
        pool.requests.put((_MessageType.Close, None))
        self.run_through_loop(pool=pool)
        pool.join()

    def test_loop_get_raises_error(self):
        pool = self.run_through_loop(_error)
        self.assertRaises(Exception, pool.get)
        pool.requests.put((_MessageType.Close, None))
        pool.close()
        pool.join()

    def test_loop_get_raises_interrupt(self):
        pool = self.run_through_loop(_interrupt)
        self.assertRaises(KeyboardInterrupt, pool.get)
        pool.requests.put((_MessageType.Close, None))
        pool.close()
        pool.join()

    def test_pickling_errors(self):
        def unpicklable_fn():  # pragma: no cover
            pass

        host = Host()
        jobs = 2
        self.assertRaises(ValueError, make_pool,
                          host, jobs, _stub, unpicklable_fn, None, None)
        self.assertRaises(ValueError, make_pool,
                          host, jobs, _stub, None, unpicklable_fn, None)
        self.assertRaises(ValueError, make_pool,
                          host, jobs, _stub, None, None, unpicklable_fn)

    def test_no_close(self):
        host = Host()
        context = {'pre': False, 'post': False}
        pool = make_pool(host, 2, _echo, context, _pre, _post)
        final_contexts = pool.join()
        self.assertEqual(final_contexts, [])