summaryrefslogtreecommitdiffstats
path: root/testing/legion/rpc_server.py
blob: 43b431707e851c0601058c401a3dcf3d0201c255 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""The task RPC server code.

This server is an XML-RPC server which serves code from
rpc_methods.RPCMethods.

This server will run until shutdown is called on the server object. This can
be achieved in 2 ways:

- Calling the Quit RPC method defined in RPCMethods
- Not receiving any calls within the idle_timeout_secs time.
"""

import logging
import threading
import time
import xmlrpclib
import SimpleXMLRPCServer
import SocketServer

#pylint: disable=relative-import
import common_lib
import rpc_methods


class RequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
  """Rejects XML-RPC POSTs that do not come from the authorized address.

  This class assumes self.server is an RPCServer, i.e. that it exposes an
  authorized_address attribute holding the single client IP address that is
  allowed to perform RPCs.
  """

  def do_POST(self):
    """Verifies the task is authorized to perform RPCs.

    A request whose source IP differs from self.server.authorized_address is
    answered with a plain-text 403 "Forbidden" response and is never
    dispatched. Any other request is handed to the standard XML-RPC POST
    handler.

    Returns:
      The parent class' do_POST result for authorized requests, otherwise
      None.
    """
    client_ip = self.client_address[0]
    if client_ip != self.server.authorized_address:
      # Log the peer's real address. The handler has no task_address
      # attribute; referencing it raised AttributeError and the 403 was
      # never sent.
      logging.error('Received unauthorized RPC request from %s', client_ip)
      self.send_response(403)
      response = 'Forbidden'
      self.send_header('Content-type', 'text/plain')
      self.send_header('Content-length', str(len(response)))
      self.end_headers()
      self.wfile.write(response)
    else:
      return SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.do_POST(self)


class RPCServer(SocketServer.ThreadingMixIn,
                SimpleXMLRPCServer.SimpleXMLRPCServer):
  """XML-RPC server that only serves requests from one authorized address.

  Requests are handled on per-request threads via SocketServer.ThreadingMixIn.
  The mix-in must be listed first so that its process_request overrides the
  one inherited from SocketServer.BaseServer; listed second, it is silently
  ignored and requests are processed serially.

  The server shuts itself down if no RPC is dispatched within
  idle_timeout_secs.
  """

  def __init__(self, authorized_address,
               idle_timeout_secs=common_lib.DEFAULT_TIMEOUT_SECS):
    """Binds the server and registers the RPC methods.

    Args:
      authorized_address: The only client IP address allowed to perform
        RPCs; RequestHandler compares it against the client's address.
      idle_timeout_secs: Seconds without any dispatched RPC after which the
        server shuts itself down.
    """
    SimpleXMLRPCServer.SimpleXMLRPCServer.__init__(
        self, (common_lib.SERVER_ADDRESS, common_lib.SERVER_PORT),
        allow_none=True, logRequests=False,
        requestHandler=RequestHandler)

    self.authorized_address = authorized_address
    self.idle_timeout_secs = idle_timeout_secs
    self.register_instance(rpc_methods.RPCMethods(self))

    self._shutdown_requested_event = threading.Event()
    self._rpc_received_event = threading.Event()
    self._idle_thread = threading.Thread(target=self._CheckForIdleQuit)

  def shutdown(self):
    """Shutdown the server.

    This overloaded method sets the _shutdown_requested_event to allow the
    idle timeout thread to quit, then delegates to the parent shutdown.
    """
    self._shutdown_requested_event.set()
    SimpleXMLRPCServer.SimpleXMLRPCServer.shutdown(self)
    logging.info('Server shutdown complete')

  def serve_forever(self, poll_interval=0.5):
    """Serve forever.

    This overloaded method starts the idle timeout thread before calling
    serve_forever. This ensures the idle timer thread doesn't get started
    without the server running.

    Args:
      poll_interval: The interval to poll for shutdown.
    """
    logging.info('RPC server starting')
    self._idle_thread.start()
    SimpleXMLRPCServer.SimpleXMLRPCServer.serve_forever(self, poll_interval)

  def _dispatch(self, method, params):
    """Dispatch the call to the correct method with the provided params.

    This overloaded method adds logging to help trace connection and
    call problems, and records that an RPC arrived so the idle timer resets.

    Args:
      method: The method name to call.
      params: A tuple of parameters to pass.

    Returns:
      The result of the parent class' _dispatch method.
    """
    logging.debug('Calling %s%s', method, params)
    self._rpc_received_event.set()
    return SimpleXMLRPCServer.SimpleXMLRPCServer._dispatch(self, method, params)

  def _CheckForIdleQuit(self):
    """Check for, and exit, if the server is idle for too long.

    This method must be run in a separate thread to avoid a deadlock when
    calling server.shutdown. It returns without doing anything if shutdown()
    has already been requested by someone else.
    """
    timeout = time.time() + self.idle_timeout_secs
    while time.time() < timeout:
      if self._shutdown_requested_event.is_set():
        # An external source called shutdown()
        return
      if self._rpc_received_event.is_set():
        logging.debug('Resetting the idle timeout')
        # Clear before recomputing the deadline so an RPC arriving in between
        # is recorded for the next iteration instead of being lost.
        self._rpc_received_event.clear()
        timeout = time.time() + self.idle_timeout_secs
      # Block on the shutdown event instead of sleeping so an external
      # shutdown() wakes this thread immediately.
      self._shutdown_requested_event.wait(1)
    if self._shutdown_requested_event.is_set():
      # shutdown() was requested during the final wait; don't report an idle
      # timeout or shut down a second time.
      return
    # We timed out, kill the server
    logging.warning('Shutting down the server due to the idle timeout')
    self.shutdown()