summary | refs | log | tree | commit | diff | stats
path: root/tools/playback_benchmark/proxy_handler.py
blob: 05da0780de26e0de24b3f1c3c43aabfd185dc29b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python
#
# Copyright 2010 Google Inc. All Rights Reserved.

"""HTTP proxy request handler with SSL support.

  RequestHandler: Utility class for parsing HTTP requests.
  ProxyHandler: HTTP proxy handler.
"""

import BaseHTTPServer
import cgi
import OpenSSL
import os
import socket
import SocketServer
import sys
import traceback
import urlparse


class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Reads HTTP requests from, and writes HTTP responses to, given files.

  An instance is built directly around an existing pair of file objects
  rather than being created by a socket server for a new connection.
  """

  protocol_version = "HTTP/1.1"
  request_version = protocol_version

  class HTTPRequestException(Exception):
    """Raised by ReadRequest when a request cannot be read or parsed."""

  def __init__(self, rfile, wfile, server):
    """Stores the streams and the server value on the instance.

    The base-class constructor is deliberately not called: it expects a
    (request, client_address, server) triple and starts handling the
    request right away.

    Args:
      rfile: file-like object the requests are read from.
      wfile: file-like object the responses are written to.
      server: value stored as self.server.
    """
    self.rfile = rfile
    self.wfile = wfile
    self.server = server

  def ReadRequest(self):
    """Reads and parses a single HTTP request from self.rfile.

    On success self.headers is a plain dict keyed by lower-cased header
    name, and self.body holds the request body, or None when the request
    has no content-length header.

    Raises:
      RequestHandler.HTTPRequestException: if no request line could be read
        (self.close_connection is set to 1 first) or if parse_request()
        reports failure.
    """
    self.raw_requestline = self.rfile.readline()
    if not self.raw_requestline:
      self.close_connection = 1
      # The exception class is an attribute of this class, so it has to be
      # reached through self; the bare name is not in scope inside a method.
      raise self.HTTPRequestException('failed to read request line')
    if not self.parse_request():
      raise self.HTTPRequestException('failed to parse request')
    # Lower-case the names explicitly so the 'content-length' lookup below
    # does not depend on how the client capitalised the header.
    self.headers = dict(
        (name.lower(), value) for name, value in self.headers.items())
    self.body = None
    if 'content-length' in self.headers:
      self.body = self.rfile.read(int(self.headers['content-length']))

  def log_message(self, format, *args):
    """Discards log messages for requests handled by this class."""
    pass


class ProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Request handler class for the proxy server.

  GET and POST requests are handed to self.driver.ProcessRequest. CONNECT
  requests are answered locally: an SSL connection is accepted on the client
  socket using a certificate file named after the requested host, and each
  request read from it is passed to the same driver.

  The driver and certificate_directory attributes are not defined in this
  class; CreateServer assigns them on the class before the server starts.
  """

  server_version = "PlaybackProxy/0.0.1"
  protocol_version = "HTTP/1.1"

  def do_CONNECT(self):
    """Handles a CONNECT request by serving the tunnel's requests locally.

    Replies 501 when certificate_directory has no file named after the
    requested host. Otherwise replies 200, performs the server side of the
    SSL handshake on the client socket and processes requests until the
    client goes away or a request asks for the connection to be closed.
    """
    # NOTE(review): the host part of the client-supplied CONNECT target is
    # joined into a file path without validation.
    server = self.path.split(':')[0]
    certificate_file = os.path.join(self.certificate_directory, server)
    if not os.path.isfile(certificate_file):
      sys.stderr.write('request to connect %s is ignored\n' % server)
      self.send_response(501)
      self.send_header('Proxy-agent', self.version_string())
      self.end_headers()
      return

    # Send confirmation to browser.
    self.send_response(200, 'Connection established')
    self.send_header('Proxy-agent', self.version_string())
    self.end_headers()

    # Create SSL context. The same file supplies the key and certificate.
    context = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
    context.use_privatekey_file(certificate_file)
    context.use_certificate_file(certificate_file)

    # Create and initialize the SSL connection on top of the TCP socket.
    ssl_connection = OpenSSL.SSL.Connection(context, self.connection)
    ssl_connection.set_accept_state()
    ssl_connection.do_handshake()
    ssl_rfile = socket._fileobject(ssl_connection, "rb", self.rbufsize)
    ssl_wfile = socket._fileobject(ssl_connection, "wb", self.wbufsize)

    # Handle http requests coming from ssl_connection.
    handler = RequestHandler(ssl_rfile, ssl_wfile, self.path)
    try:
      handler.close_connection = 1
      while True:
        handler.ReadRequest()
        self.driver.ProcessRequest(handler)
        if handler.close_connection: break
    except (OpenSSL.SSL.SysCallError, OpenSSL.SSL.ZeroReturnError):
      # The peer dropped or shut down the SSL connection.
      pass
    except RequestHandler.HTTPRequestException:
      # No further request could be read or parsed (for example the client
      # closed the tunnel); this is the normal end of the tunnel, not an
      # error of the outer connection.
      pass
    finally:
      # The socket now carries SSL traffic, so the outer handler must not
      # try to read another plain HTTP request from it.
      self.close_connection = 1

  def do_GET(self):
    """Passes a GET request to the driver with self.body set to None."""
    # One handler instance serves every request on a persistent connection,
    # so clear any body left behind by an earlier POST.
    self.body = None
    self.driver.ProcessRequest(self)

  def do_POST(self):
    """Reads the request body, if any, and passes the request to the driver.

    self.body is the content-length bytes read from self.rfile, or None when
    the request has no content-length header.
    """
    self.body = None
    if 'content-length' in self.headers:
      self.body = self.rfile.read(int(self.headers['content-length']))
    self.driver.ProcessRequest(self)

  def log_message(self, format, *args):
    """Writes the formatted log message, followed by a newline, to stdout."""
    sys.stdout.write((format % args) + '\n')


class ThreadingHTTPServer(SocketServer.ThreadingMixIn,
                          BaseHTTPServer.HTTPServer):
  """HTTPServer combined with ThreadingMixIn; adds nothing of its own."""


def CreateServer(driver, port, certificate_directory=None):
  """Configures ProxyHandler and returns a server that uses it.

  Args:
    driver: object stored as ProxyHandler.driver.
    port: port number to bind on all interfaces.
    certificate_directory: directory stored as
      ProxyHandler.certificate_directory; when empty or None, the
      'certificates' directory under the current working directory is used.

  Returns:
    A ThreadingHTTPServer bound to ('', port) with ProxyHandler as its
    request handler class.
  """
  cert_dir = certificate_directory or os.path.join(os.getcwd(),
                                                   'certificates')
  # Both settings live on the handler class itself, so they are shared by
  # every handler instance.
  ProxyHandler.certificate_directory = cert_dir
  ProxyHandler.driver = driver
  return ThreadingHTTPServer(('', port), ProxyHandler)