summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorricea@chromium.org <ricea@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2014-06-05 11:11:15 +0000
committerricea@chromium.org <ricea@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2014-06-05 11:11:15 +0000
commita6244952c7f13fce9e137e20d27f6700db58ae9f (patch)
tree316edc923e04ef79ced0309249914f31270c12b1
parente1c913c665472a11722a2e657c54dc3830d85378 (diff)
downloadchromium_src-a6244952c7f13fce9e137e20d27f6700db58ae9f.zip
chromium_src-a6244952c7f13fce9e137e20d27f6700db58ae9f.tar.gz
chromium_src-a6244952c7f13fce9e137e20d27f6700db58ae9f.tar.bz2
Support recovery from SSL errors.
Previously, the new WebSocket implementation was unable to handle sites with self-signed certificates and other cases where the user had overridden certificate errors. Add code to support this case. This requires adding infrastructure to pass the SSL error back up to the content layer which knows how to handle it. It also requires that the ID of the frame be known, so an extra parameter has been added to the WebSocketHostMsg_AddChannelRequest IPC message. BUG=364361 Review URL: https://codereview.chromium.org/304093003 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@275066 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r--content/browser/renderer_host/websocket_dispatcher_host.h2
-rw-r--r--content/browser/renderer_host/websocket_dispatcher_host_unittest.cc15
-rw-r--r--content/browser/renderer_host/websocket_host.cc119
-rw-r--r--content/browser/renderer_host/websocket_host.h11
-rw-r--r--content/child/websocket_bridge.cc7
-rw-r--r--content/common/websocket_messages.h5
-rw-r--r--net/websockets/websocket_channel.cc17
-rw-r--r--net/websockets/websocket_channel.h9
-rw-r--r--net/websockets/websocket_channel_test.cc56
-rw-r--r--net/websockets/websocket_event_interface.h29
-rw-r--r--net/websockets/websocket_handshake_stream_create_helper_test.cc5
-rw-r--r--net/websockets/websocket_stream.cc59
-rw-r--r--net/websockets/websocket_stream.h10
-rw-r--r--net/websockets/websocket_stream_test.cc71
-rw-r--r--net/websockets/websocket_test_util.cc28
-rw-r--r--net/websockets/websocket_test_util.h34
16 files changed, 430 insertions, 47 deletions
diff --git a/content/browser/renderer_host/websocket_dispatcher_host.h b/content/browser/renderer_host/websocket_dispatcher_host.h
index b95b580..4179b71 100644
--- a/content/browser/renderer_host/websocket_dispatcher_host.h
+++ b/content/browser/renderer_host/websocket_dispatcher_host.h
@@ -111,6 +111,8 @@ class CONTENT_EXPORT WebSocketDispatcherHost : public BrowserMessageFilter {
// Returns whether the associated renderer process can read raw cookies.
bool CanReadRawCookies() const;
+ int render_process_id() const { return process_id_; }
+
private:
typedef base::hash_map<int, WebSocketHost*> WebSocketHostTable;
diff --git a/content/browser/renderer_host/websocket_dispatcher_host_unittest.cc b/content/browser/renderer_host/websocket_dispatcher_host_unittest.cc
index 27d5eee..e1506d9 100644
--- a/content/browser/renderer_host/websocket_dispatcher_host_unittest.cc
+++ b/content/browser/renderer_host/websocket_dispatcher_host_unittest.cc
@@ -20,6 +20,9 @@
namespace content {
namespace {
+// This number is unlikely to occur by chance.
+static const int kMagicRenderProcessId = 506116062;
+
// A mock of WebsocketHost which records received messages.
class MockWebSocketHost : public WebSocketHost {
public:
@@ -43,7 +46,7 @@ class WebSocketDispatcherHostTest : public ::testing::Test {
public:
WebSocketDispatcherHostTest() {
dispatcher_host_ = new WebSocketDispatcherHost(
- 0,
+ kMagicRenderProcessId,
base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext,
base::Unretained(this)),
base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost,
@@ -81,14 +84,19 @@ TEST_F(WebSocketDispatcherHostTest, UnrelatedMessage) {
EXPECT_FALSE(dispatcher_host_->OnMessageReceived(message));
}
+TEST_F(WebSocketDispatcherHostTest, RenderProcessIdGetter) {
+ EXPECT_EQ(kMagicRenderProcessId, dispatcher_host_->render_process_id());
+}
+
TEST_F(WebSocketDispatcherHostTest, AddChannelRequest) {
int routing_id = 123;
GURL socket_url("ws://example.com/test");
std::vector<std::string> requested_protocols;
requested_protocols.push_back("hello");
url::Origin origin("http://example.com/test");
+ int render_frame_id = -2;
WebSocketHostMsg_AddChannelRequest message(
- routing_id, socket_url, requested_protocols, origin);
+ routing_id, socket_url, requested_protocols, origin, render_frame_id);
ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message));
@@ -120,8 +128,9 @@ TEST_F(WebSocketDispatcherHostTest, SendFrame) {
std::vector<std::string> requested_protocols;
requested_protocols.push_back("hello");
url::Origin origin("http://example.com/test");
+ int render_frame_id = -2;
WebSocketHostMsg_AddChannelRequest add_channel_message(
- routing_id, socket_url, requested_protocols, origin);
+ routing_id, socket_url, requested_protocols, origin, render_frame_id);
ASSERT_TRUE(dispatcher_host_->OnMessageReceived(add_channel_message));
diff --git a/content/browser/renderer_host/websocket_host.cc b/content/browser/renderer_host/websocket_host.cc
index ed398f8d..7f63918 100644
--- a/content/browser/renderer_host/websocket_host.cc
+++ b/content/browser/renderer_host/websocket_host.cc
@@ -5,13 +5,17 @@
#include "content/browser/renderer_host/websocket_host.h"
#include "base/basictypes.h"
+#include "base/memory/weak_ptr.h"
#include "base/strings/string_util.h"
#include "content/browser/renderer_host/websocket_dispatcher_host.h"
+#include "content/browser/ssl/ssl_error_handler.h"
+#include "content/browser/ssl/ssl_manager.h"
#include "content/common/websocket_messages.h"
#include "ipc/ipc_message_macros.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
+#include "net/ssl/ssl_info.h"
#include "net/websockets/websocket_channel.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
@@ -80,7 +84,9 @@ ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) {
// renderer or child process via WebSocketDispatcherHost.
class WebSocketEventHandler : public net::WebSocketEventInterface {
public:
- WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, int routing_id);
+ WebSocketEventHandler(WebSocketDispatcherHost* dispatcher,
+ int routing_id,
+ int render_frame_id);
virtual ~WebSocketEventHandler();
// net::WebSocketEventInterface implementation
@@ -102,18 +108,50 @@ class WebSocketEventHandler : public net::WebSocketEventInterface {
scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
virtual ChannelState OnFinishOpeningHandshake(
scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
+ virtual ChannelState OnSSLCertificateError(
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
+ const GURL& url,
+ const net::SSLInfo& ssl_info,
+ bool fatal) OVERRIDE;
private:
+ class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate {
+ public:
+ SSLErrorHandlerDelegate(
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks);
+ virtual ~SSLErrorHandlerDelegate();
+
+ base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();
+
+ // SSLErrorHandler::Delegate methods
+ virtual void CancelSSLRequest(const GlobalRequestID& id,
+ int error,
+ const net::SSLInfo* ssl_info) OVERRIDE;
+ virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE;
+
+ private:
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
+ base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
+ };
+
WebSocketDispatcherHost* const dispatcher_;
const int routing_id_;
+ const int render_frame_id_;
+ scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;
DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
};
WebSocketEventHandler::WebSocketEventHandler(
WebSocketDispatcherHost* dispatcher,
- int routing_id)
- : dispatcher_(dispatcher), routing_id_(routing_id) {}
+ int routing_id,
+ int render_frame_id)
+ : dispatcher_(dispatcher),
+ routing_id_(routing_id),
+ render_frame_id_(render_frame_id) {
+}
WebSocketEventHandler::~WebSocketEventHandler() {
DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_;
@@ -227,18 +265,67 @@ ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
response_to_pass));
}
+ChannelState WebSocketEventHandler::OnSSLCertificateError(
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
+ const GURL& url,
+ const net::SSLInfo& ssl_info,
+ bool fatal) {
+ DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
+ << " routing_id=" << routing_id_ << " url=" << url.spec()
+ << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;
+ ssl_error_handler_delegate_.reset(
+ new SSLErrorHandlerDelegate(callbacks.Pass()));
+ // We don't need request_id to be unique so just make a fake one.
+ GlobalRequestID request_id(-1, -1);
+ SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(),
+ request_id,
+ ResourceType::SUB_RESOURCE,
+ url,
+ dispatcher_->render_process_id(),
+ render_frame_id_,
+ ssl_info,
+ fatal);
+ // The above method is always asynchronous.
+ return WebSocketEventInterface::CHANNEL_ALIVE;
+}
+
+WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks)
+ : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {}
+
+WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}
+
+base::WeakPtr<SSLErrorHandler::Delegate>
+WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
+ return weak_ptr_factory_.GetWeakPtr();
+}
+
+void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
+ const GlobalRequestID& id,
+ int error,
+ const net::SSLInfo* ssl_info) {
+ DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
+ << " error=" << error
+ << " cert_status=" << (ssl_info ? ssl_info->cert_status
+ : static_cast<net::CertStatus>(-1));
+ callbacks_->CancelSSLRequest(error, ssl_info);
+}
+
+void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest(
+ const GlobalRequestID& id) {
+ DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
+ callbacks_->ContinueSSLRequest();
+}
+
} // namespace
WebSocketHost::WebSocketHost(int routing_id,
WebSocketDispatcherHost* dispatcher,
net::URLRequestContext* url_request_context)
- : routing_id_(routing_id) {
+ : dispatcher_(dispatcher),
+ url_request_context_(url_request_context),
+ routing_id_(routing_id) {
DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id;
-
- scoped_ptr<net::WebSocketEventInterface> event_interface(
- new WebSocketEventHandler(dispatcher, routing_id));
- channel_.reset(
- new net::WebSocketChannel(event_interface.Pass(), url_request_context));
}
WebSocketHost::~WebSocketHost() {}
@@ -258,15 +345,20 @@ bool WebSocketHost::OnMessageReceived(const IPC::Message& message) {
void WebSocketHost::OnAddChannelRequest(
const GURL& socket_url,
const std::vector<std::string>& requested_protocols,
- const url::Origin& origin) {
+ const url::Origin& origin,
+ int render_frame_id) {
DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
<< " routing_id=" << routing_id_ << " socket_url=\"" << socket_url
<< "\" requested_protocols=\""
<< JoinString(requested_protocols, ", ") << "\" origin=\""
<< origin.string() << "\"";
- channel_->SendAddChannelRequest(
- socket_url, requested_protocols, origin);
+ DCHECK(!channel_);
+ scoped_ptr<net::WebSocketEventInterface> event_interface(
+ new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id));
+ channel_.reset(
+ new net::WebSocketChannel(event_interface.Pass(), url_request_context_));
+ channel_->SendAddChannelRequest(socket_url, requested_protocols, origin);
}
void WebSocketHost::OnSendFrame(bool fin,
@@ -276,6 +368,7 @@ void WebSocketHost::OnSendFrame(bool fin,
<< " routing_id=" << routing_id_ << " fin=" << fin
<< " type=" << type << " data is " << data.size() << " bytes";
+ DCHECK(channel_);
channel_->SendFrame(fin, MessageTypeToOpCode(type), data);
}
@@ -283,6 +376,7 @@ void WebSocketHost::OnFlowControl(int64 quota) {
DVLOG(3) << "WebSocketHost::OnFlowControl"
<< " routing_id=" << routing_id_ << " quota=" << quota;
+ DCHECK(channel_);
channel_->SendFlowControl(quota);
}
@@ -293,6 +387,7 @@ void WebSocketHost::OnDropChannel(bool was_clean,
<< " routing_id=" << routing_id_ << " was_clean=" << was_clean
<< " code=" << code << " reason=\"" << reason << "\"";
+ DCHECK(channel_);
// TODO(yhirano): Handle |was_clean| appropriately.
channel_->StartClosingHandshake(code, reason);
}
diff --git a/content/browser/renderer_host/websocket_host.h b/content/browser/renderer_host/websocket_host.h
index 148d6ed..21c77dc 100644
--- a/content/browser/renderer_host/websocket_host.h
+++ b/content/browser/renderer_host/websocket_host.h
@@ -50,7 +50,8 @@ class CONTENT_EXPORT WebSocketHost {
void OnAddChannelRequest(const GURL& socket_url,
const std::vector<std::string>& requested_protocols,
- const url::Origin& origin);
+ const url::Origin& origin,
+ int render_frame_id);
void OnSendFrame(bool fin,
WebSocketMessageType type,
@@ -63,8 +64,14 @@ class CONTENT_EXPORT WebSocketHost {
// The channel we use to send events to the network.
scoped_ptr<net::WebSocketChannel> channel_;
+ // The WebSocketHostDispatcher that created this object.
+ WebSocketDispatcherHost* const dispatcher_;
+
+ // The URL request context for the channel.
+ net::URLRequestContext* const url_request_context_;
+
// The ID used to route messages.
- int routing_id_;
+ const int routing_id_;
DISALLOW_COPY_AND_ASSIGN(WebSocketHost);
};
diff --git a/content/child/websocket_bridge.cc b/content/child/websocket_bridge.cc
index a213bd7..5b32702 100644
--- a/content/child/websocket_bridge.cc
+++ b/content/child/websocket_bridge.cc
@@ -227,11 +227,8 @@ void WebSocketBridge::connect(
<< JoinString(protocols_to_pass, ", ") << "), "
<< origin_to_pass.string() << ")";
- ChildThread::current()->Send(
- new WebSocketHostMsg_AddChannelRequest(channel_id_,
- url,
- protocols_to_pass,
- origin_to_pass));
+ ChildThread::current()->Send(new WebSocketHostMsg_AddChannelRequest(
+ channel_id_, url, protocols_to_pass, origin_to_pass, render_frame_id_));
}
void WebSocketBridge::send(bool fin,
diff --git a/content/common/websocket_messages.h b/content/common/websocket_messages.h
index 0ece66a..d42b38b 100644
--- a/content/common/websocket_messages.h
+++ b/content/common/websocket_messages.h
@@ -58,10 +58,11 @@ IPC_STRUCT_TRAITS_END()
// The browser process will not send |channel_id| as-is to the remote server; it
// will try to use a short id on the wire. This saves the renderer from
// having to try to choose the ids cleverly.
-IPC_MESSAGE_ROUTED3(WebSocketHostMsg_AddChannelRequest,
+IPC_MESSAGE_ROUTED4(WebSocketHostMsg_AddChannelRequest,
GURL /* socket_url */,
std::vector<std::string> /* requested_protocols */,
- url::Origin /* origin */)
+ url::Origin /* origin */,
+ int /* render_frame_id */)
// WebSocket messages sent from the browser to the renderer.
diff --git a/net/websockets/websocket_channel.cc b/net/websockets/websocket_channel.cc
index 47114f8..c27f8dd 100644
--- a/net/websockets/websocket_channel.cc
+++ b/net/websockets/websocket_channel.cc
@@ -176,6 +176,15 @@ class WebSocketChannel::ConnectDelegate
creator_->OnFinishOpeningHandshake(response.Pass());
}
+ virtual void OnSSLCertificateError(
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
+ ssl_error_callbacks,
+ const SSLInfo& ssl_info,
+ bool fatal) OVERRIDE {
+ creator_->OnSSLCertificateError(
+ ssl_error_callbacks.Pass(), ssl_info, fatal);
+ }
+
private:
// A pointer to the WebSocketChannel that created this object. There is no
// danger of this pointer being stale, because deleting the WebSocketChannel
@@ -576,6 +585,14 @@ void WebSocketChannel::OnConnectFailure(const std::string& message) {
// |this| has been deleted.
}
+void WebSocketChannel::OnSSLCertificateError(
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,
+ const SSLInfo& ssl_info,
+ bool fatal) {
+ AllowUnused(event_interface_->OnSSLCertificateError(
+ ssl_error_callbacks.Pass(), socket_url_, ssl_info, fatal));
+}
+
void WebSocketChannel::OnStartOpeningHandshake(
scoped_ptr<WebSocketHandshakeRequestInfo> request) {
DCHECK(!notification_sender_->handshake_request_info());
diff --git a/net/websockets/websocket_channel.h b/net/websockets/websocket_channel.h
index 5c7a6cf..6d5640e 100644
--- a/net/websockets/websocket_channel.h
+++ b/net/websockets/websocket_channel.h
@@ -204,6 +204,15 @@ class NET_EXPORT WebSocketChannel {
// failure to the event interface. May delete |this|.
void OnConnectFailure(const std::string& message);
+ // SSL certificate error callback from
+ // WebSocketStream::CreateAndConnectStream(). Forwards the request to the
+ // event interface.
+ void OnSSLCertificateError(
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
+ ssl_error_callbacks,
+ const SSLInfo& ssl_info,
+ bool fatal);
+
// Posts a task that sends pending notifications relating WebSocket Opening
// Handshake to the renderer.
void ScheduleOpeningHandshakeNotification();
diff --git a/net/websockets/websocket_channel_test.cc b/net/websockets/websocket_channel_test.cc
index 464c30e..4a8f119 100644
--- a/net/websockets/websocket_channel_test.cc
+++ b/net/websockets/websocket_channel_test.cc
@@ -97,6 +97,7 @@ using ::testing::AnyNumber;
using ::testing::DefaultValue;
using ::testing::InSequence;
using ::testing::MockFunction;
+using ::testing::NotNull;
using ::testing::Return;
using ::testing::SaveArg;
using ::testing::StrictMock;
@@ -171,9 +172,21 @@ class MockWebSocketEventInterface : public WebSocketEventInterface {
OnFinishOpeningHandshakeCalled();
return CHANNEL_ALIVE;
}
+ virtual ChannelState OnSSLCertificateError(
+ scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks,
+ const GURL& url,
+ const SSLInfo& ssl_info,
+ bool fatal) OVERRIDE {
+ OnSSLCertificateErrorCalled(
+ ssl_error_callbacks.get(), url, ssl_info, fatal);
+ return CHANNEL_ALIVE;
+ }
MOCK_METHOD0(OnStartOpeningHandshakeCalled, void()); // NOLINT
MOCK_METHOD0(OnFinishOpeningHandshakeCalled, void()); // NOLINT
+ MOCK_METHOD4(
+ OnSSLCertificateErrorCalled,
+ void(SSLErrorCallbacks*, const GURL&, const SSLInfo&, bool)); // NOLINT
};
// This fake EventInterface is for tests which need a WebSocketEventInterface
@@ -210,6 +223,13 @@ class FakeWebSocketEventInterface : public WebSocketEventInterface {
scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {
return CHANNEL_ALIVE;
}
+ virtual ChannelState OnSSLCertificateError(
+ scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks,
+ const GURL& url,
+ const SSLInfo& ssl_info,
+ bool fatal) OVERRIDE {
+ return CHANNEL_ALIVE;
+ }
};
// This fake WebSocketStream is for tests that require a WebSocketStream but are
@@ -713,6 +733,13 @@ std::vector<char> AsVector(const std::string& s) {
return std::vector<char>(s.begin(), s.end());
}
+class FakeSSLErrorCallbacks
+ : public WebSocketEventInterface::SSLErrorCallbacks {
+ public:
+ virtual void CancelSSLRequest(int error, const SSLInfo* ssl_info) OVERRIDE {}
+ virtual void ContinueSSLRequest() OVERRIDE {}
+};
+
// Base class for all test fixtures.
class WebSocketChannelTest : public ::testing::Test {
protected:
@@ -797,6 +824,7 @@ enum EventInterfaceCall {
EVENT_ON_DROP_CHANNEL = 0x20,
EVENT_ON_START_OPENING_HANDSHAKE = 0x40,
EVENT_ON_FINISH_OPENING_HANDSHAKE = 0x80,
+ EVENT_ON_SSL_CERTIFICATE_ERROR = 0x100,
};
class WebSocketChannelDeletingTest : public WebSocketChannelTest {
@@ -818,7 +846,8 @@ class WebSocketChannelDeletingTest : public WebSocketChannelTest {
EVENT_ON_FAIL_CHANNEL |
EVENT_ON_DROP_CHANNEL |
EVENT_ON_START_OPENING_HANDSHAKE |
- EVENT_ON_FINISH_OPENING_HANDSHAKE) {}
+ EVENT_ON_FINISH_OPENING_HANDSHAKE |
+ EVENT_ON_SSL_CERTIFICATE_ERROR) {}
// Create a ChannelDeletingFakeWebSocketEventInterface. Defined out-of-line to
// avoid circular dependency.
virtual scoped_ptr<WebSocketEventInterface> CreateEventInterface() OVERRIDE;
@@ -877,6 +906,13 @@ class ChannelDeletingFakeWebSocketEventInterface
scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {
return fixture_->DeleteIfDeleting(EVENT_ON_FINISH_OPENING_HANDSHAKE);
}
+ virtual ChannelState OnSSLCertificateError(
+ scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks,
+ const GURL& url,
+ const SSLInfo& ssl_info,
+ bool fatal) OVERRIDE {
+ return fixture_->DeleteIfDeleting(EVENT_ON_SSL_CERTIFICATE_ERROR);
+ }
private:
// A pointer to the test fixture. Owned by the test harness; this object will
@@ -3209,6 +3245,24 @@ TEST_F(WebSocketChannelEventInterfaceTest, DataFramesNonEmptyOrFinal) {
CreateChannelAndConnectSuccessfully();
}
+// Calls to OnSSLCertificateError() must be passed through to the event
+// interface with the correct URL attached.
+TEST_F(WebSocketChannelEventInterfaceTest, OnSSLCertificateErrorCalled) {
+ const GURL wss_url("wss://example.com/sslerror");
+ connect_data_.socket_url = wss_url;
+ const SSLInfo ssl_info;
+ const bool fatal = true;
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> fake_callbacks(
+ new FakeSSLErrorCallbacks);
+
+ EXPECT_CALL(*event_interface_,
+ OnSSLCertificateErrorCalled(NotNull(), wss_url, _, fatal));
+
+ CreateChannelAndConnect();
+ connect_data_.creator.connect_delegate->OnSSLCertificateError(
+ fake_callbacks.Pass(), ssl_info, fatal);
+}
+
// If we receive another frame after Close, it is not valid. It is not
// completely clear what behaviour is required from the standard in this case,
// but the current implementation fails the connection. Since a Close has
diff --git a/net/websockets/websocket_event_interface.h b/net/websockets/websocket_event_interface.h
index 923581a..d32a7c1 100644
--- a/net/websockets/websocket_event_interface.h
+++ b/net/websockets/websocket_event_interface.h
@@ -12,8 +12,11 @@
#include "base/compiler_specific.h" // for WARN_UNUSED_RESULT
#include "net/base/net_export.h"
+class GURL;
+
namespace net {
+class SSLInfo;
struct WebSocketHandshakeRequestInfo;
struct WebSocketHandshakeResponseInfo;
@@ -99,6 +102,32 @@ class NET_EXPORT WebSocketEventInterface {
scoped_ptr<WebSocketHandshakeResponseInfo> response)
WARN_UNUSED_RESULT = 0;
+ // Callbacks to be used in response to a call to OnSSLCertificateError. Very
+ // similar to content::SSLErrorHandler::Delegate (which we can't use directly
+ // due to layering constraints).
+ class NET_EXPORT SSLErrorCallbacks {
+ public:
+ virtual ~SSLErrorCallbacks() {}
+
+ // Cancels the SSL response in response to the error.
+ virtual void CancelSSLRequest(int error, const SSLInfo* ssl_info) = 0;
+
+ // Continue with the SSL connection despite the error.
+ virtual void ContinueSSLRequest() = 0;
+ };
+
+ // Called on SSL Certificate Error during the SSL handshake. Should result in
+ // a call to either ssl_error_callbacks->ContinueSSLRequest() or
+ // ssl_error_callbacks->CancelSSLRequest(). Normally the implementation of
+ // this method will delegate to content::SSLManager::OnSSLCertificateError to
+ // make the actual decision. The callbacks must not be called after the
+ // WebSocketChannel has been destroyed.
+ virtual ChannelState OnSSLCertificateError(
+ scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks,
+ const GURL& url,
+ const SSLInfo& ssl_info,
+ bool fatal) WARN_UNUSED_RESULT = 0;
+
protected:
WebSocketEventInterface() {}
diff --git a/net/websockets/websocket_handshake_stream_create_helper_test.cc b/net/websockets/websocket_handshake_stream_create_helper_test.cc
index 652f1bc..b5ec6fb 100644
--- a/net/websockets/websocket_handshake_stream_create_helper_test.cc
+++ b/net/websockets/websocket_handshake_stream_create_helper_test.cc
@@ -67,6 +67,11 @@ class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE {}
virtual void OnFinishOpeningHandshake(
scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {}
+ virtual void OnSSLCertificateError(
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
+ ssl_error_callbacks,
+ const SSLInfo& ssl_info,
+ bool fatal) OVERRIDE {}
};
class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test {
diff --git a/net/websockets/websocket_stream.cc b/net/websockets/websocket_stream.cc
index 546e01b..9880ea8 100644
--- a/net/websockets/websocket_stream.cc
+++ b/net/websockets/websocket_stream.cc
@@ -13,6 +13,7 @@
#include "net/url_request/url_request.h"
#include "net/url_request/url_request_context.h"
#include "net/websockets/websocket_errors.h"
+#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "net/websockets/websocket_handshake_stream_base.h"
#include "net/websockets/websocket_handshake_stream_create_helper.h"
@@ -42,6 +43,17 @@ class Delegate : public URLRequest::Delegate {
}
// Implementation of URLRequest::Delegate methods.
+ virtual void OnReceivedRedirect(URLRequest* request,
+ const GURL& new_url,
+ bool* defer_redirect) OVERRIDE {
+ // HTTP status codes returned by HttpStreamParser are filtered by
+ // WebSocketBasicHandshakeStream, and only 101, 401 and 407 are permitted
+ // back up the stack to HttpNetworkTransaction. In particular, redirect
+ // codes are never allowed, and so URLRequest never sees a redirect on a
+ // WebSocket request.
+ NOTREACHED();
+ }
+
virtual void OnResponseStarted(URLRequest* request) OVERRIDE;
virtual void OnAuthRequired(URLRequest* request,
@@ -125,6 +137,10 @@ class StreamRequestImpl : public WebSocketStreamRequest {
connect_delegate_->OnFailure(failure_message);
}
+ WebSocketStream::ConnectDelegate* connect_delegate() const {
+ return connect_delegate_.get();
+ }
+
private:
// |delegate_| needs to be declared before |url_request_| so that it gets
// initialised first.
@@ -140,7 +156,35 @@ class StreamRequestImpl : public WebSocketStreamRequest {
WebSocketHandshakeStreamCreateHelper* create_helper_;
};
+class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks {
+ public:
+ explicit SSLErrorCallbacks(URLRequest* url_request)
+ : url_request_(url_request) {}
+
+ virtual void CancelSSLRequest(int error, const SSLInfo* ssl_info) OVERRIDE {
+ if (ssl_info) {
+ url_request_->CancelWithSSLError(error, *ssl_info);
+ } else {
+ url_request_->CancelWithError(error);
+ }
+ }
+
+ virtual void ContinueSSLRequest() OVERRIDE {
+ url_request_->ContinueDespiteLastError();
+ }
+
+ private:
+ URLRequest* url_request_;
+};
+
void Delegate::OnResponseStarted(URLRequest* request) {
+ if (!request->status().is_success()) {
+ DVLOG(3) << "OnResponseStarted (request failed)";
+ owner_->ReportFailure();
+ return;
+ }
+ DVLOG(3) << "OnResponseStarted (response code " << request->GetResponseCode()
+ << ")";
switch (request->GetResponseCode()) {
case HTTP_SWITCHING_PROTOCOLS:
result_ = CONNECTED;
@@ -159,18 +203,29 @@ void Delegate::OnResponseStarted(URLRequest* request) {
void Delegate::OnAuthRequired(URLRequest* request,
AuthChallengeInfo* auth_info) {
+ // This should only be called if credentials are not already stored.
request->CancelAuth();
}
void Delegate::OnCertificateRequested(URLRequest* request,
SSLCertRequestInfo* cert_request_info) {
- request->ContinueWithCertificate(NULL);
+ // This method is called when a client certificate is requested, and the
+ // request context does not already contain a client certificate selection for
+ // the endpoint. In this case, a main frame resource request would pop-up UI
+ // to permit selection of a client certificate, but since WebSockets are
+ // sub-resources they should not pop-up UI and so there is nothing more we can
+ // do.
+ request->Cancel();
}
void Delegate::OnSSLCertificateError(URLRequest* request,
const SSLInfo& ssl_info,
bool fatal) {
- request->Cancel();
+ owner_->connect_delegate()->OnSSLCertificateError(
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>(
+ new SSLErrorCallbacks(request)),
+ ssl_info,
+ fatal);
}
void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) {
diff --git a/net/websockets/websocket_stream.h b/net/websockets/websocket_stream.h
index d881a37..09f11b2 100644
--- a/net/websockets/websocket_stream.h
+++ b/net/websockets/websocket_stream.h
@@ -14,6 +14,7 @@
#include "base/memory/scoped_vector.h"
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
+#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_handshake_request_info.h"
#include "net/websockets/websocket_handshake_response_info.h"
@@ -74,6 +75,15 @@ class NET_EXPORT_PRIVATE WebSocketStream {
// Called when the WebSocket Opening Handshake ends.
virtual void OnFinishOpeningHandshake(
scoped_ptr<WebSocketHandshakeResponseInfo> response) = 0;
+
+ // Called when there is an SSL certificate error. Should call
+ // ssl_error_callbacks->ContinueSSLRequest() or
+ // ssl_error_callbacks->CancelSSLRequest().
+ virtual void OnSSLCertificateError(
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
+ ssl_error_callbacks,
+ const SSLInfo& ssl_info,
+ bool fatal) = 0;
};
// Create and connect a WebSocketStream of an appropriate type. The actual
diff --git a/net/websockets/websocket_stream_test.cc b/net/websockets/websocket_stream_test.cc
index 4ea8538..0344cce 100644
--- a/net/websockets/websocket_stream_test.cc
+++ b/net/websockets/websocket_stream_test.cc
@@ -16,10 +16,12 @@
#include "base/run_loop.h"
#include "base/strings/stringprintf.h"
#include "net/base/net_errors.h"
+#include "net/base/test_data_directory.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/socket_test_util.h"
+#include "net/test/cert_test_util.h"
#include "net/url_request/url_request_test_util.h"
#include "net/websockets/websocket_basic_handshake_stream.h"
#include "net/websockets/websocket_frame.h"
@@ -79,7 +81,7 @@ class DeterministicKeyWebSocketHandshakeStreamCreateHelper
class WebSocketStreamCreateTest : public ::testing::Test {
public:
- WebSocketStreamCreateTest(): has_failed_(false) {}
+ WebSocketStreamCreateTest() : has_failed_(false), ssl_fatal_(false) {}
void CreateAndConnectCustomResponse(
const std::string& socket_url,
@@ -116,7 +118,7 @@ class WebSocketStreamCreateTest : public ::testing::Test {
const std::vector<std::string>& sub_protocols,
const std::string& origin,
scoped_ptr<DeterministicSocketData> socket_data) {
- url_request_context_host_.SetRawExpectations(socket_data.Pass());
+ url_request_context_host_.AddRawExpectations(socket_data.Pass());
CreateAndConnectStream(socket_url, sub_protocols, origin);
}
@@ -125,6 +127,12 @@ class WebSocketStreamCreateTest : public ::testing::Test {
void CreateAndConnectStream(const std::string& socket_url,
const std::vector<std::string>& sub_protocols,
const std::string& origin) {
+ for (size_t i = 0; i < ssl_data_.size(); ++i) {
+ scoped_ptr<SSLSocketDataProvider> ssl_data(ssl_data_[i]);
+ ssl_data_[i] = NULL;
+ url_request_context_host_.AddSSLSocketDataProvider(ssl_data.Pass());
+ }
+ ssl_data_.clear();
scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate(
new TestConnectDelegate(this));
WebSocketStream::ConnectDelegate* delegate = connect_delegate.get();
@@ -175,6 +183,15 @@ class WebSocketStreamCreateTest : public ::testing::Test {
ADD_FAILURE();
owner_->response_info_ = response.Pass();
}
+ virtual void OnSSLCertificateError(
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
+ ssl_error_callbacks,
+ const SSLInfo& ssl_info,
+ bool fatal) OVERRIDE {
+ owner_->ssl_error_callbacks_ = ssl_error_callbacks.Pass();
+ owner_->ssl_info_ = ssl_info;
+ owner_->ssl_fatal_ = fatal;
+ }
private:
WebSocketStreamCreateTest* owner_;
@@ -189,6 +206,10 @@ class WebSocketStreamCreateTest : public ::testing::Test {
bool has_failed_;
scoped_ptr<WebSocketHandshakeRequestInfo> request_info_;
scoped_ptr<WebSocketHandshakeResponseInfo> response_info_;
+ scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks_;
+ SSLInfo ssl_info_;
+ bool ssl_fatal_;
+ ScopedVector<SSLSocketDataProvider> ssl_data_;
};
// There are enough tests of the Sec-WebSocket-Extensions header that they
@@ -1032,6 +1053,48 @@ TEST_F(WebSocketStreamCreateTest, NoResponse) {
failure_message());
}
+TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateFailure) {
+ ssl_data_.push_back(
+ new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID));
+ ssl_data_[0]->cert =
+ ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
+ ASSERT_TRUE(ssl_data_[0]->cert);
+ scoped_ptr<DeterministicSocketData> raw_socket_data(
+ new DeterministicSocketData(NULL, 0, NULL, 0));
+ CreateAndConnectRawExpectations("wss://localhost/",
+ NoSubProtocols(),
+ "http://localhost",
+ raw_socket_data.Pass());
+ RunUntilIdle();
+ EXPECT_FALSE(has_failed());
+ ASSERT_TRUE(ssl_error_callbacks_);
+ ssl_error_callbacks_->CancelSSLRequest(ERR_CERT_AUTHORITY_INVALID,
+ &ssl_info_);
+ RunUntilIdle();
+ EXPECT_TRUE(has_failed());
+}
+
+TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateSuccess) {
+ scoped_ptr<SSLSocketDataProvider> ssl_data(
+ new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID));
+ ssl_data->cert =
+ ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
+ ASSERT_TRUE(ssl_data->cert);
+ ssl_data_.push_back(ssl_data.release());
+ ssl_data.reset(new SSLSocketDataProvider(ASYNC, OK));
+ ssl_data_.push_back(ssl_data.release());
+ url_request_context_host_.AddRawExpectations(
+ make_scoped_ptr(new DeterministicSocketData(NULL, 0, NULL, 0)));
+ CreateAndConnectStandard(
+ "wss://localhost/", "/", NoSubProtocols(), "http://localhost", "", "");
+ RunUntilIdle();
+ ASSERT_TRUE(ssl_error_callbacks_);
+ ssl_error_callbacks_->ContinueSSLRequest();
+ RunUntilIdle();
+ EXPECT_FALSE(has_failed());
+ EXPECT_TRUE(stream_);
+}
+
TEST_F(WebSocketStreamCreateUMATest, Incomplete) {
const std::string name("Net.WebSocket.HandshakeResult");
scoped_ptr<base::HistogramSamples> original(GetSamples(name));
@@ -1107,9 +1170,9 @@ TEST_F(WebSocketStreamCreateUMATest, Failed) {
if (original) {
samples->Subtract(*original); // Cancel the original values.
}
- EXPECT_EQ(0, samples->GetCount(INCOMPLETE));
+ EXPECT_EQ(1, samples->GetCount(INCOMPLETE));
EXPECT_EQ(0, samples->GetCount(CONNECTED));
- EXPECT_EQ(1, samples->GetCount(FAILED));
+ EXPECT_EQ(0, samples->GetCount(FAILED));
}
} // namespace
diff --git a/net/websockets/websocket_test_util.cc b/net/websockets/websocket_test_util.cc
index 7605780..bfa8980 100644
--- a/net/websockets/websocket_test_util.cc
+++ b/net/websockets/websocket_test_util.cc
@@ -8,6 +8,7 @@
#include <vector>
#include "base/basictypes.h"
+#include "base/memory/scoped_vector.h"
#include "base/stl_util.h"
#include "base/strings/stringprintf.h"
#include "net/socket/socket_test_util.h"
@@ -72,7 +73,8 @@ struct WebSocketDeterministicMockClientSocketFactoryMaker::Detail {
std::string return_to_read;
std::vector<MockRead> reads;
MockWrite write;
- scoped_ptr<DeterministicSocketData> data;
+ ScopedVector<DeterministicSocketData> socket_data_vector;
+ ScopedVector<SSLSocketDataProvider> ssl_socket_data_vector;
DeterministicMockClientSocketFactory factory;
};
@@ -117,13 +119,20 @@ void WebSocketDeterministicMockClientSocketFactoryMaker::SetExpectations(
1));
socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK));
socket_data->SetStop(sequence);
- SetRawExpectations(socket_data.Pass());
+ AddRawExpectations(socket_data.Pass());
}
-void WebSocketDeterministicMockClientSocketFactoryMaker::SetRawExpectations(
+void WebSocketDeterministicMockClientSocketFactoryMaker::AddRawExpectations(
scoped_ptr<DeterministicSocketData> socket_data) {
- detail_->data = socket_data.Pass();
- detail_->factory.AddSocketDataProvider(detail_->data.get());
+ detail_->factory.AddSocketDataProvider(socket_data.get());
+ detail_->socket_data_vector.push_back(socket_data.release());
+}
+
+void
+WebSocketDeterministicMockClientSocketFactoryMaker::AddSSLSocketDataProvider(
+ scoped_ptr<SSLSocketDataProvider> ssl_socket_data) {
+ detail_->factory.AddSSLSocketDataProvider(ssl_socket_data.get());
+ detail_->ssl_socket_data_vector.push_back(ssl_socket_data.release());
}
WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost()
@@ -133,9 +142,14 @@ WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost()
WebSocketTestURLRequestContextHost::~WebSocketTestURLRequestContextHost() {}
-void WebSocketTestURLRequestContextHost::SetRawExpectations(
+void WebSocketTestURLRequestContextHost::AddRawExpectations(
scoped_ptr<DeterministicSocketData> socket_data) {
- maker_.SetRawExpectations(socket_data.Pass());
+ maker_.AddRawExpectations(socket_data.Pass());
+}
+
+void WebSocketTestURLRequestContextHost::AddSSLSocketDataProvider(
+ scoped_ptr<SSLSocketDataProvider> ssl_socket_data) {
+ maker_.AddSSLSocketDataProvider(ssl_socket_data.Pass());
}
TestURLRequestContext*
diff --git a/net/websockets/websocket_test_util.h b/net/websockets/websocket_test_util.h
index 5dba6cd..2ad86c0 100644
--- a/net/websockets/websocket_test_util.h
+++ b/net/websockets/websocket_test_util.h
@@ -21,10 +21,11 @@ class Origin;
namespace net {
class BoundNetLog;
+class DeterministicMockClientSocketFactory;
class DeterministicSocketData;
class URLRequestContext;
class WebSocketHandshakeStreamCreateHelper;
-class DeterministicMockClientSocketFactory;
+struct SSLSocketDataProvider;
class LinearCongruentialGenerator {
public:
@@ -65,15 +66,26 @@ class WebSocketDeterministicMockClientSocketFactoryMaker {
WebSocketDeterministicMockClientSocketFactoryMaker();
~WebSocketDeterministicMockClientSocketFactoryMaker();
- // The socket created by the factory will expect |expect_written| to be
- // written to the socket, and will respond with |return_to_read|. The test
- // will fail if the expected text is not written, or all the bytes are not
- // read.
+ // Tell the factory to create a socket which expects |expect_written| to be
+ // written, and responds with |return_to_read|. The test will fail if the
+ // expected text is not written, or all the bytes are not read. This adds data
+ // for a new mock-socket using AddRawExpections(), and so can be called
+ // multiple times to queue up multiple mock sockets, but usually in those
+ // cases the lower-level AddRawExpections() interface is more appropriate.
void SetExpectations(const std::string& expect_written,
const std::string& return_to_read);
- // A low-level interface to permit arbitrary expectations to be set.
- void SetRawExpectations(scoped_ptr<DeterministicSocketData> socket_data);
+ // A low-level interface to permit arbitrary expectations to be added. The
+ // mock sockets will be created in the same order that they were added.
+ void AddRawExpectations(scoped_ptr<DeterministicSocketData> socket_data);
+
+ // Allow an SSL socket data provider to be added. You must also supply a mock
+ // transport socket for it to use. If the mock SSL handshake fails then the
+ // mock transport socket will connect but have nothing read or written. If the
+ // mock handshake succeeds then the data from the underlying transport socket
+ // will be passed through unchanged (without encryption).
+ void AddSSLSocketDataProvider(
+ scoped_ptr<SSLSocketDataProvider> ssl_socket_data);
// Call to get a pointer to the factory, which remains owned by this object.
DeterministicMockClientSocketFactory* factory();
@@ -98,9 +110,13 @@ struct WebSocketTestURLRequestContextHost {
maker_.SetExpectations(expect_written, return_to_read);
}
- void SetRawExpectations(scoped_ptr<DeterministicSocketData> socket_data);
+ void AddRawExpectations(scoped_ptr<DeterministicSocketData> socket_data);
+
+ // Allow an SSL socket data provider to be added.
+ void AddSSLSocketDataProvider(
+ scoped_ptr<SSLSocketDataProvider> ssl_socket_data);
- // Call after calling one of SetExpections() or SetRawExpectations(). The
+ // Call after calling one of SetExpections() or AddRawExpectations(). The
// returned pointer remains owned by this object. This should only be called
// once.
TestURLRequestContext* GetURLRequestContext();