summaryrefslogtreecommitdiffstats
path: root/tools/isolate/worker_pool.py
blob: 8edaee6e6ae0d85293c1b0134da307d5646ca0b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright (c) 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Implements a multithreaded worker pool oriented for mapping jobs with
thread-local result storage.
"""

import Queue
import sys
import time
import threading


class QueueWithTimeout(Queue.Queue):
  """A Queue.Queue whose join() accepts an optional timeout."""

  # QueueWithTimeout.join: Arguments number differs from overridden method
  # pylint: disable=W0221
  def join(self, timeout=None):
    """Blocks until all tasks are done or |timeout| seconds have elapsed.

    Args:
      timeout: Maximum number of seconds to wait. None or 0 waits
          indefinitely, like Queue.Queue.join().

    Returns:
      True if all tasks are finished, False if the timeout expired first.
    """
    if not timeout:
      Queue.Queue.join(self)
      return True
    deadline = time.time() + timeout
    with self.all_tasks_done:
      while self.unfinished_tasks:
        # Seconds left before the deadline; wait() may wake early (on
        # task_done notifications), so recompute each iteration.
        remaining = deadline - time.time()
        if remaining <= 0:
          break
        self.all_tasks_done.wait(remaining)
      return not self.unfinished_tasks


class WorkerThread(threading.Thread):
  """Daemon thread that runs queued tasks and keeps their results locally.

  Each task's return value is appended to |outputs|; any exception raised by a
  task is appended to |exceptions| as a sys.exc_info() tuple. The thread
  starts itself on construction.
  """
  def __init__(self, tasks, *args, **kwargs):
    """
    Args:
      tasks: Queue of (func, args, kwargs) tuples; a None item stops the
          thread.
      args, kwargs: Forwarded to threading.Thread.
    """
    super(WorkerThread, self).__init__(*args, **kwargs)
    self._tasks = tasks
    self.outputs = []
    self.exceptions = []

    self.daemon = True
    self.start()

  def run(self):
    """Runs until a None task is dequeued."""
    while True:
      task = self._tasks.get()
      try:
        if task is None:
          # We're done.
          return
        func, args, kwargs = task
        self.outputs.append(func(*args, **kwargs))
      except Exception:
        self.exceptions.append(sys.exc_info())
      finally:
        # Balance every get(), including the None sentinel, so the queue's
        # unfinished-task count returns to zero and a later join() can't hang.
        self._tasks.task_done()


class ThreadPool(object):
  def __init__(self, num_threads):
    self._tasks = QueueWithTimeout()
    self._workers = [
      WorkerThread(self._tasks, name='worker-%d' % i)
      for i in range(num_threads)
    ]

  def add_task(self, func, *args, **kwargs):
    """Adds a task, a function to be executed by a worker.

    The function's return value will be stored in the the worker's thread local
    outputs list.
    """
    self._tasks.put((func, args, kwargs))

  def join(self, progress=None, timeout=None):
    """Extracts all the results from each threads unordered."""
    if progress and timeout:
      while not self._tasks.join(timeout):
        progress.print_update()
    else:
      self._tasks.join()
    out = []
    for w in self._workers:
      if w.exceptions:
        raise w.exceptions[0][0], w.exceptions[0][1], w.exceptions[0][2]
      out.extend(w.outputs)
      w.outputs = []
    # Look for exceptions.
    return out

  def close(self):
    """Closes all the threads."""
    for _ in range(len(self._workers)):
      # Enqueueing None causes the worker to stop.
      self._tasks.put(None)
    for t in self._workers:
      t.join()

  def __enter__(self):
    """Enables 'with' statement."""
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Enables 'with' statement."""
    self.close()


class Progress(object):
  """Prints progress and accepts updates thread-safely."""
  def __init__(self, size):
    """
    Args:
      size: Expected total number of items; may be raised later with
          increase_count().
    """
    self.last_printed_line = ''
    self.next_line = ''
    self.index = -1
    self.size = size
    self.start = time.time()
    self.lock = threading.Lock()
    # Formats the initial '0 of N' line (update_item increments index to 0).
    self.update_item('')

  def update_item(self, name):
    """Counts one more processed item and formats the pending status line.

    Args:
      name: Text appended to the end of the status line.
    """
    with self.lock:
      self.index += 1
      # Avoid ZeroDivisionError when no items are expected.
      if self.size:
        percent = self.index * 100. / self.size
      else:
        percent = 0.
      self.next_line = '%d of %d (%.1f%%), %.1fs: %s' % (
            self.index,
            self.size,
            percent,
            time.time() - self.start,
            name)

  def print_update(self):
    """Prints the current status to stderr if it changed since last print."""
    with self.lock:
      if self.next_line == self.last_printed_line:
        return
      # '\r' rewrites the same terminal line; trailing spaces erase leftovers
      # from a longer previously printed line.
      line = '\r%s%s' % (
          self.next_line,
          ' ' * max(0, len(self.last_printed_line) - len(self.next_line)))
      self.last_printed_line = self.next_line
    sys.stderr.write(line)

  def increase_count(self):
    """Raises the expected total by one.

    The pending status line is not reformatted until the next update_item().
    """
    with self.lock:
      self.size += 1