summaryrefslogtreecommitdiffstats
path: root/testing/legion/task_controller.py
blob: 5522a2b2bca7a104775585c16f7ff8906fdcfa89 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Defines the task controller library."""

import argparse
import datetime
import logging
import os
import socket
import subprocess
import sys
import tempfile
import threading
import xmlrpclib

#pylint: disable=relative-import
import common_lib
import process
import ssl_util

ISOLATE_PY = os.path.join(common_lib.SWARMING_DIR, 'isolate.py')
SWARMING_PY = os.path.join(common_lib.SWARMING_DIR, 'swarming.py')


class Error(Exception):
  pass


class ConnectionTimeoutError(Error):
  pass


class TaskController(object):
  """Provisions, configures, and controls a task machine.

  This class is an abstraction of a physical task machine. It provides an
  end to end API for controlling a task machine. Operations on the task machine
  are performed using the instance's "rpc" property. A simple end to end
  scenario is as follows:

  task = TaskController(...)
  task.Create()
  task.WaitForConnection()
  proc = task.rpc.subprocess.Popen(['ls'])
  print task.rpc.subprocess.GetStdout(proc)
  task.Release()
  """

  _task_count = 0
  _tasks = []

  def __init__(self, isolated_hash, dimensions, priority=100,
               idle_timeout_secs=common_lib.DEFAULT_TIMEOUT_SECS,
               connection_timeout_secs=common_lib.DEFAULT_TIMEOUT_SECS,
               verbosity='ERROR', name=None, run_id=None):
    assert isinstance(dimensions, dict)
    type(self)._tasks.append(self)
    type(self)._task_count += 1
    self.verbosity = verbosity
    self._name = name or 'Task%d' % type(self)._task_count
    self._priority = priority
    self._isolated_hash = isolated_hash
    self._idle_timeout_secs = idle_timeout_secs
    self._dimensions = dimensions
    self._connect_event = threading.Event()
    self._connected = False
    self._ip_address = None
    self._otp = self._CreateOTP()
    self._rpc = None

    run_id = run_id or datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    self._task_name = '%s/%s/%s' % (
        os.path.splitext(sys.argv[0])[0], self._name, run_id)

    parser = argparse.ArgumentParser()
    parser.add_argument('--isolate-server')
    parser.add_argument('--swarming-server')
    parser.add_argument('--task-connection-timeout-secs',
                        default=common_lib.DEFAULT_TIMEOUT_SECS)
    args, _ = parser.parse_known_args()

    self._isolate_server = args.isolate_server
    self._swarming_server = args.swarming_server
    self._connection_timeout_secs = (connection_timeout_secs or
                                    args.task_connection_timeout_secs)

  @property
  def name(self):
    return self._name

  @property
  def otp(self):
    return self._otp

  @property
  def connected(self):
    return self._connected

  @property
  def connect_event(self):
    return self._connect_event

  @property
  def rpc(self):
    return self._rpc

  @property
  def verbosity(self):
    return self._verbosity

  @verbosity.setter
  def verbosity(self, level):
    """Sets the verbosity level as a string.

    Either a string ('INFO', 'DEBUG', etc) or a logging level (logging.INFO,
    logging.DEBUG, etc) is allowed.
    """
    assert isinstance(level, (str, int))
    if isinstance(level, int):
      level = logging.getLevelName(level)
    self._verbosity = level  #pylint: disable=attribute-defined-outside-init

  @classmethod
  def ReleaseAllTasks(cls):
    for task in cls._tasks:
      task.Release()

  def Process(self, cmd, verbose=False, detached=False, cwd=None):
    return process.ControllerProcessWrapper(
        self.rpc, cmd, verbose, detached, cwd)

  def _CreateOTP(self):
    """Creates the OTP."""
    controller_name = socket.gethostname()
    test_name = os.path.basename(sys.argv[0])
    creation_time = datetime.datetime.utcnow()
    otp = 'task:%s controller:%s test:%s creation:%s' % (
        self._name, controller_name, test_name, creation_time)
    return otp

  def Create(self):
    """Creates the task machine."""
    logging.info('Creating %s', self.name)
    self._connect_event.clear()
    self._ExecuteSwarming()

  def WaitForConnection(self):
    """Waits for the task machine to connect.

    Raises:
      ConnectionTimeoutError if the task doesn't connect in time.
    """
    logging.info('Waiting for %s to connect with a timeout of %d seconds',
                 self._name, self._connection_timeout_secs)
    self._connect_event.wait(self._connection_timeout_secs)
    if not self._connect_event.is_set():
      raise ConnectionTimeoutError('%s failed to connect' % self.name)

  def Release(self):
    """Quits the task's RPC server so it can release the machine."""
    if self._rpc is not None and self._connected:
      logging.info('Releasing %s', self._name)
      try:
        self._rpc.Quit()
      except (socket.error, xmlrpclib.Fault):
        logging.error('Unable to connect to %s to call Quit', self.name)
      self._rpc = None
      self._connected = False

  def _ExecuteSwarming(self):
    """Executes swarming.py."""
    cmd = [
        'python',
        SWARMING_PY,
        'trigger',
        self._isolated_hash,
        '--priority', str(self._priority),
        '--task-name', self._task_name,
        ]

    if self._isolate_server:
      cmd.extend(['--isolate-server', self._isolate_server])
    if self._swarming_server:
      cmd.extend(['--swarming', self._swarming_server])
    for key, value in self._dimensions.iteritems():
      cmd.extend(['--dimension', key, value])

    cmd.extend([
        '--',
        '--controller', common_lib.MY_IP,
        '--otp', self._otp,
        '--verbosity', self._verbosity,
        '--idle-timeout', str(self._idle_timeout_secs),
        ])

    self._ExecuteProcess(cmd)

  def _ExecuteProcess(self, cmd):
    """Executes a process, waits for it to complete, and checks for success."""
    logging.debug('Running %s', ' '.join(cmd))
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    _, stderr = p.communicate()
    if p.returncode != 0:
      raise Error(stderr)

  def OnConnect(self, ip_address):
    """Receives task ip address on connection."""
    self._ip_address = ip_address
    self._connected = True
    self._rpc = ssl_util.SslRpcServer.Connect(self._ip_address)
    logging.info('%s connected from %s', self._name, ip_address)
    self._connect_event.set()