diff options
25 files changed, 1051 insertions, 940 deletions
diff --git a/net/base/file_stream_unittest.cc b/net/base/file_stream_unittest.cc index b93d886..7c2ef50 100644 --- a/net/base/file_stream_unittest.cc +++ b/net/base/file_stream_unittest.cc @@ -4,6 +4,7 @@ #include "base/callback.h" #include "base/file_util.h" +#include "base/message_loop.h" #include "base/path_service.h" #include "base/platform_file.h" #include "net/base/file_stream.h" diff --git a/net/base/registry_controlled_domain.cc b/net/base/registry_controlled_domain.cc index 88eccd0..1ee88747 100644 --- a/net/base/registry_controlled_domain.cc +++ b/net/base/registry_controlled_domain.cc @@ -51,12 +51,14 @@ namespace net { -static const int kExceptionRule = 1; -static const int kWildcardRule = 2; +namespace { -RegistryControlledDomainService::RegistryControlledDomainService() - : find_domain_function_(Perfect_Hash::FindDomain) { -} +const int kExceptionRule = 1; +const int kWildcardRule = 2; + +RegistryControlledDomainService* test_instance_; + +} // namespace // static std::string RegistryControlledDomainService::GetDomainAndRegistry( @@ -155,6 +157,34 @@ size_t RegistryControlledDomainService::GetRegistryLength( } // static +RegistryControlledDomainService* RegistryControlledDomainService::GetInstance() +{ + if (test_instance_) + return test_instance_; + + return Singleton<RegistryControlledDomainService>::get(); +} + +RegistryControlledDomainService::RegistryControlledDomainService() + : find_domain_function_(Perfect_Hash::FindDomain) { +} + +// static +RegistryControlledDomainService* RegistryControlledDomainService::SetInstance( + RegistryControlledDomainService* instance) { + RegistryControlledDomainService* old_instance = test_instance_; + test_instance_ = instance; + return old_instance; +} + +// static +void RegistryControlledDomainService::UseFindDomainFunction( + FindDomainPtr function) { + RegistryControlledDomainService* instance = GetInstance(); + instance->find_domain_function_ = function; +} + +// static std::string RegistryControlledDomainService::GetDomainAndRegistryImpl( const std::string& host) { DCHECK(!host.empty()); @@ -261,30 +291,4 @@ size_t RegistryControlledDomainService::GetRegistryLengthImpl( return allow_unknown_registries ? (host.length() - curr_start) : 0; } -static RegistryControlledDomainService* test_instance_; - -// static -RegistryControlledDomainService* RegistryControlledDomainService::SetInstance( - RegistryControlledDomainService* instance) { - RegistryControlledDomainService* old_instance = test_instance_; - test_instance_ = instance; - return old_instance; -} - -// static -RegistryControlledDomainService* RegistryControlledDomainService::GetInstance() -{ - if (test_instance_) - return test_instance_; - - return Singleton<RegistryControlledDomainService>::get(); -} - -// static -void RegistryControlledDomainService::UseFindDomainFunction( - FindDomainPtr function) { - RegistryControlledDomainService* instance = GetInstance(); - instance->find_domain_function_ = function; -} - } // namespace net diff --git a/net/base/test_completion_callback.cc b/net/base/test_completion_callback.cc new file mode 100644 index 0000000..999a71e --- /dev/null +++ b/net/base/test_completion_callback.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2011 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/test_completion_callback.h" + +#include "base/message_loop.h" +#include "net/base/net_errors.h" + +TestCompletionCallback::TestCompletionCallback() + : result_(0), + have_result_(false), + waiting_for_result_(false) { +} + +TestCompletionCallback::~TestCompletionCallback() {} + +int TestCompletionCallback::WaitForResult() { + DCHECK(!waiting_for_result_); + while (!have_result_) { + waiting_for_result_ = true; + MessageLoop::current()->Run(); + waiting_for_result_ = false; + } + have_result_ = false; // auto-reset for next callback + return result_; +} + +int TestCompletionCallback::GetResult(int result) { + if (net::ERR_IO_PENDING != result) + return result; + return WaitForResult(); +} + +void TestCompletionCallback::RunWithParams(const Tuple1<int>& params) { + result_ = params.a; + have_result_ = true; + if (waiting_for_result_) + MessageLoop::current()->Quit(); +} diff --git a/net/base/test_completion_callback.h b/net/base/test_completion_callback.h index b4a005b..ba578d4 100644 --- a/net/base/test_completion_callback.h +++ b/net/base/test_completion_callback.h @@ -7,9 +7,6 @@ #pragma once #include "base/callback.h" -#include "base/message_loop.h" -#include "net/base/completion_callback.h" -#include "net/base/net_errors.h" //----------------------------------------------------------------------------- // completion callback helper @@ -24,37 +21,16 @@ // class TestCompletionCallback : public CallbackRunner< Tuple1<int> > { public: - TestCompletionCallback() - : result_(0), - have_result_(false), - waiting_for_result_(false) { - } + TestCompletionCallback(); + virtual ~TestCompletionCallback(); - int WaitForResult() { - DCHECK(!waiting_for_result_); - while (!have_result_) { - waiting_for_result_ = true; - MessageLoop::current()->Run(); - waiting_for_result_ = false; - } - have_result_ = false; // auto-reset for next callback - return result_; - } + int WaitForResult(); - int GetResult(int result) { - if (net::ERR_IO_PENDING != result) - return result; - return WaitForResult(); - } + int GetResult(int result); bool have_result() const { return have_result_; } - virtual void RunWithParams(const Tuple1<int>& params) { - result_ = params.a; - have_result_ = true; - if (waiting_for_result_) - MessageLoop::current()->Quit(); - } + virtual void RunWithParams(const Tuple1<int>& params); private: int result_; diff --git a/net/base/test_completion_callback_unittest.cc b/net/base/test_completion_callback_unittest.cc index d0274e4..1b9d71d 100644 --- a/net/base/test_completion_callback_unittest.cc +++ b/net/base/test_completion_callback_unittest.cc @@ -4,6 +4,9 @@ // Illustrates how to use worker threads that issue completion callbacks +#include "base/logging.h" +#include "base/message_loop.h" +#include "base/task.h" #include "base/threading/worker_pool.h" #include "net/base/completion_callback.h" #include "net/base/test_completion_callback.h" diff --git a/net/http/http_auth_handler_factory.h b/net/http/http_auth_handler_factory.h index a56d5e1..1e4134f 100644 --- a/net/http/http_auth_handler_factory.h +++ b/net/http/http_auth_handler_factory.h @@ -28,6 +28,11 @@ class HttpAuthHandlerRegistryFactory; // objects that it creates. class HttpAuthHandlerFactory { public: + enum CreateReason { + CREATE_CHALLENGE, // Create a handler in response to a challenge. + CREATE_PREEMPTIVE, // Create a handler preemptively. + }; + HttpAuthHandlerFactory() : url_security_manager_(NULL) {} virtual ~HttpAuthHandlerFactory() {} @@ -42,11 +47,6 @@ class HttpAuthHandlerFactory { return url_security_manager_; } - enum CreateReason { - CREATE_CHALLENGE, // Create a handler in response to a challenge. - CREATE_PREEMPTIVE, // Create a handler preemptively. - }; - // Creates an HttpAuthHandler object based on the authentication // challenge specified by |*challenge|. |challenge| must point to a valid // non-NULL tokenizer. diff --git a/net/http/http_cache.h b/net/http/http_cache.h index 3438ba7..b431ee6 100644 --- a/net/http/http_cache.h +++ b/net/http/http_cache.h @@ -225,6 +225,7 @@ class HttpCache : public HttpTransactionFactory, typedef base::hash_map<std::string, ActiveEntry*> ActiveEntriesMap; typedef base::hash_map<std::string, PendingOp*> PendingOpsMap; typedef std::set<ActiveEntry*> ActiveEntriesSet; + typedef base::hash_map<std::string, int> PlaybackCacheMap; // Methods ------------------------------------------------------------------ @@ -371,7 +372,6 @@ class HttpCache : public HttpTransactionFactory, ScopedRunnableMethodFactory<HttpCache> task_factory_; - typedef base::hash_map<std::string, int> PlaybackCacheMap; scoped_ptr<PlaybackCacheMap> playback_cache_map_; DISALLOW_COPY_AND_ASSIGN(HttpCache); diff --git a/net/http/http_proxy_client_socket_pool.cc b/net/http/http_proxy_client_socket_pool.cc index d2e3ccb..3129133 100644 --- a/net/http/http_proxy_client_socket_pool.cc +++ b/net/http/http_proxy_client_socket_pool.cc @@ -105,12 +105,11 @@ LoadState HttpProxyConnectJob::GetLoadState() const { } } -int HttpProxyConnectJob::ConnectInternal() { - if (params_->tcp_params()) - next_state_ = STATE_TCP_CONNECT; - else - next_state_ = STATE_SSL_CONNECT; - return DoLoop(OK); +void HttpProxyConnectJob::GetAdditionalErrorState(ClientSocketHandle * handle) { + if (error_response_info_.cert_request_info) { + handle->set_ssl_error_response_info(error_response_info_); + handle->set_is_ssl_error(true); + } } void HttpProxyConnectJob::OnIOComplete(int result) { @@ -248,11 +247,33 @@ int HttpProxyConnectJob::DoSSLConnectComplete(int result) { return result; } -void HttpProxyConnectJob::GetAdditionalErrorState(ClientSocketHandle * handle) { - if (error_response_info_.cert_request_info) { - handle->set_ssl_error_response_info(error_response_info_); - handle->set_is_ssl_error(true); +int HttpProxyConnectJob::DoHttpProxyConnect() { + next_state_ = STATE_HTTP_PROXY_CONNECT_COMPLETE; + const HostResolver::RequestInfo& tcp_destination = params_->destination(); + const HostPortPair& proxy_server = tcp_destination.host_port_pair(); + + // Add a HttpProxy connection on top of the tcp socket. + transport_socket_.reset( + new HttpProxyClientSocket(transport_socket_handle_.release(), + params_->request_url(), + params_->user_agent(), + params_->endpoint(), + proxy_server, + params_->http_auth_cache(), + params_->http_auth_handler_factory(), + params_->tunnel(), + using_spdy_, + params_->ssl_params() != NULL)); + return transport_socket_->Connect(&callback_); +} + +int HttpProxyConnectJob::DoHttpProxyConnectComplete(int result) { + if (result == OK || result == ERR_PROXY_AUTH_REQUESTED || + result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) { + set_socket(transport_socket_.release()); } + + return result; } int HttpProxyConnectJob::DoSpdyProxyCreateStream() { @@ -303,33 +324,12 @@ int HttpProxyConnectJob::DoSpdyProxyCreateStreamComplete(int result) { return transport_socket_->Connect(&callback_); } -int HttpProxyConnectJob::DoHttpProxyConnect() { - next_state_ = STATE_HTTP_PROXY_CONNECT_COMPLETE; - const HostResolver::RequestInfo& tcp_destination = params_->destination(); - const HostPortPair& proxy_server = tcp_destination.host_port_pair(); - - // Add a HttpProxy connection on top of the tcp socket. - transport_socket_.reset( - new HttpProxyClientSocket(transport_socket_handle_.release(), - params_->request_url(), - params_->user_agent(), - params_->endpoint(), - proxy_server, - params_->http_auth_cache(), - params_->http_auth_handler_factory(), - params_->tunnel(), - using_spdy_, - params_->ssl_params() != NULL)); - return transport_socket_->Connect(&callback_); -} - -int HttpProxyConnectJob::DoHttpProxyConnectComplete(int result) { - if (result == OK || result == ERR_PROXY_AUTH_REQUESTED || - result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) { - set_socket(transport_socket_.release()); - } - - return result; +int HttpProxyConnectJob::ConnectInternal() { + if (params_->tcp_params()) + next_state_ = STATE_TCP_CONNECT; + else + next_state_ = STATE_SSL_CONNECT; + return DoLoop(OK); } HttpProxyClientSocketPool:: diff --git a/net/http/http_proxy_client_socket_pool.h b/net/http/http_proxy_client_socket_pool.h index 91963d8..4757b27 100644 --- a/net/http/http_proxy_client_socket_pool.h +++ b/net/http/http_proxy_client_socket_pool.h @@ -123,15 +123,6 @@ class HttpProxyConnectJob : public ConnectJob { STATE_NONE, }; - // Begins the tcp connection and the optional Http proxy tunnel. If the - // request is not immediately servicable (likely), the request will return - // ERR_IO_PENDING. An OK return from this function or the callback means - // that the connection is established; ERR_PROXY_AUTH_REQUESTED means - // that the tunnel needs authentication credentials, the socket will be - // returned in this case, and must be release back to the pool; or - // a standard net error code will be returned. - virtual int ConnectInternal(); - void OnIOComplete(int result); // Runs the state transition loop. @@ -150,6 +141,15 @@ class HttpProxyConnectJob : public ConnectJob { int DoSpdyProxyCreateStream(); int DoSpdyProxyCreateStreamComplete(int result); + // Begins the tcp connection and the optional Http proxy tunnel. If the + // request is not immediately servicable (likely), the request will return + // ERR_IO_PENDING. An OK return from this function or the callback means + // that the connection is established; ERR_PROXY_AUTH_REQUESTED means + // that the tunnel needs authentication credentials, the socket will be + // returned in this case, and must be release back to the pool; or + // a standard net error code will be returned. + virtual int ConnectInternal(); + scoped_refptr<HttpProxySocketParams> params_; TCPClientSocketPool* const tcp_pool_; SSLClientSocketPool* const ssl_pool_; diff --git a/net/net.gyp b/net/net.gyp index 485ac38..7dedb46 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -189,7 +189,6 @@ 'base/transport_security_state.cc', 'base/transport_security_state.h', 'base/sys_addrinfo.h', - 'base/test_completion_callback.h', 'base/upload_data.cc', 'base/upload_data.h', 'base/upload_data_stream.cc', @@ -1159,6 +1158,8 @@ 'sources': [ 'base/cert_test_util.cc', 'base/cert_test_util.h', + 'base/test_completion_callback.cc', + 'base/test_completion_callback.h', 'disk_cache/disk_cache_test_util.cc', 'disk_cache/disk_cache_test_util.h', 'proxy/proxy_config_service_common_unittest.cc', @@ -1422,6 +1423,7 @@ 'type': 'executable', 'dependencies': [ 'net', + 'net_test_support', '../base/base.gyp:base', ], 'sources': [ diff --git a/net/proxy/multi_threaded_proxy_resolver_unittest.cc b/net/proxy/multi_threaded_proxy_resolver_unittest.cc index 78409e5..8d2907c 100644 --- a/net/proxy/multi_threaded_proxy_resolver_unittest.cc +++ b/net/proxy/multi_threaded_proxy_resolver_unittest.cc @@ -4,6 +4,7 @@ #include "net/proxy/multi_threaded_proxy_resolver.h" +#include "base/message_loop.h" #include "base/stl_util-inl.h" #include "base/string_util.h" #include "base/stringprintf.h" diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc index a7bebed..d5ee9ae 100644 --- a/net/socket/socket_test_util.cc +++ b/net/socket/socket_test_util.cc @@ -116,451 +116,6 @@ void DumpMockRead(const MockRead& r) { } // namespace -MockClientSocket::MockClientSocket(net::NetLog* net_log) - : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), - connected_(false), - net_log_(NetLog::Source(), net_log) { -} - -void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { - NOTREACHED(); -} - -void MockClientSocket::GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info) { -} - -SSLClientSocket::NextProtoStatus -MockClientSocket::GetNextProto(std::string* proto) { - proto->clear(); - return SSLClientSocket::kNextProtoUnsupported; -} - -void MockClientSocket::Disconnect() { - connected_ = false; -} - -bool MockClientSocket::IsConnected() const { - return connected_; -} - -bool MockClientSocket::IsConnectedAndIdle() const { - return connected_; -} - -int MockClientSocket::GetPeerAddress(AddressList* address) const { - return net::SystemHostResolverProc("localhost", ADDRESS_FAMILY_UNSPECIFIED, - 0, address, NULL); -} - -void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback, - int result) { - MessageLoop::current()->PostTask(FROM_HERE, - method_factory_.NewRunnableMethod( - &MockClientSocket::RunCallback, callback, result)); -} - -void MockClientSocket::RunCallback(net::CompletionCallback* callback, - int result) { - if (callback) - callback->Run(result); -} - -MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, - net::NetLog* net_log, - net::SocketDataProvider* data) - : MockClientSocket(net_log), - addresses_(addresses), - data_(data), - read_offset_(0), - read_data_(false, net::ERR_UNEXPECTED), - need_read_data_(true), - peer_closed_connection_(false), - pending_buf_(NULL), - pending_buf_len_(0), - pending_callback_(NULL), - was_used_to_convey_data_(false) { - DCHECK(data_); - data_->Reset(); -} - -int MockTCPClientSocket::Connect(net::CompletionCallback* callback) { - if (connected_) - return net::OK; - connected_ = true; - peer_closed_connection_ = false; - if (data_->connect_data().async) { - RunCallbackAsync(callback, data_->connect_data().result); - return net::ERR_IO_PENDING; - } - return data_->connect_data().result; -} - -void MockTCPClientSocket::Disconnect() { - MockClientSocket::Disconnect(); - pending_callback_ = NULL; -} - -bool MockTCPClientSocket::IsConnected() const { - return connected_ && !peer_closed_connection_; -} - -int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) { - if (!connected_) - return net::ERR_UNEXPECTED; - - // If the buffer is already in use, a read is already in progress! - DCHECK(pending_buf_ == NULL); - - // Store our async IO data. - pending_buf_ = buf; - pending_buf_len_ = buf_len; - pending_callback_ = callback; - - if (need_read_data_) { - read_data_ = data_->GetNextRead(); - if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { - // This MockRead is just a marker to instruct us to set - // peer_closed_connection_. Skip it and get the next one. - read_data_ = data_->GetNextRead(); - peer_closed_connection_ = true; - } - // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility - // to complete the async IO manually later (via OnReadComplete). - if (read_data_.result == ERR_IO_PENDING) { - DCHECK(callback); // We need to be using async IO in this case. - return ERR_IO_PENDING; - } - need_read_data_ = false; - } - - return CompleteRead(); -} - -int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) { - DCHECK(buf); - DCHECK_GT(buf_len, 0); - - if (!connected_) - return net::ERR_UNEXPECTED; - - std::string data(buf->data(), buf_len); - net::MockWriteResult write_result = data_->OnWrite(data); - - was_used_to_convey_data_ = true; - - if (write_result.async) { - RunCallbackAsync(callback, write_result.result); - return net::ERR_IO_PENDING; - } - - return write_result.result; -} - -void MockTCPClientSocket::OnReadComplete(const MockRead& data) { - // There must be a read pending. - DCHECK(pending_buf_); - // You can't complete a read with another ERR_IO_PENDING status code. - DCHECK_NE(ERR_IO_PENDING, data.result); - // Since we've been waiting for data, need_read_data_ should be true. - DCHECK(need_read_data_); - - read_data_ = data; - need_read_data_ = false; - - // The caller is simulating that this IO completes right now. Don't - // let CompleteRead() schedule a callback. - read_data_.async = false; - - net::CompletionCallback* callback = pending_callback_; - int rv = CompleteRead(); - RunCallback(callback, rv); -} - -int MockTCPClientSocket::CompleteRead() { - DCHECK(pending_buf_); - DCHECK(pending_buf_len_ > 0); - - was_used_to_convey_data_ = true; - - // Save the pending async IO data and reset our |pending_| state. - net::IOBuffer* buf = pending_buf_; - int buf_len = pending_buf_len_; - net::CompletionCallback* callback = pending_callback_; - pending_buf_ = NULL; - pending_buf_len_ = 0; - pending_callback_ = NULL; - - int result = read_data_.result; - DCHECK(result != ERR_IO_PENDING); - - if (read_data_.data) { - if (read_data_.data_len - read_offset_ > 0) { - result = std::min(buf_len, read_data_.data_len - read_offset_); - memcpy(buf->data(), read_data_.data + read_offset_, result); - read_offset_ += result; - if (read_offset_ == read_data_.data_len) { - need_read_data_ = true; - read_offset_ = 0; - } - } else { - result = 0; // EOF - } - } - - if (read_data_.async) { - DCHECK(callback); - RunCallbackAsync(callback, result); - return net::ERR_IO_PENDING; - } - return result; -} - -DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( - net::NetLog* net_log, net::DeterministicSocketData* data) - : MockClientSocket(net_log), - write_pending_(false), - write_callback_(NULL), - write_result_(0), - read_data_(), - read_buf_(NULL), - read_buf_len_(0), - read_pending_(false), - read_callback_(NULL), - data_(data), - was_used_to_convey_data_(false) {} - -void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} - -// TODO(erikchen): Support connect sequencing. -int DeterministicMockTCPClientSocket::Connect( - net::CompletionCallback* callback) { - if (connected_) - return net::OK; - connected_ = true; - if (data_->connect_data().async) { - RunCallbackAsync(callback, data_->connect_data().result); - return net::ERR_IO_PENDING; - } - return data_->connect_data().result; -} - -void DeterministicMockTCPClientSocket::Disconnect() { - MockClientSocket::Disconnect(); -} - -bool DeterministicMockTCPClientSocket::IsConnected() const { - return connected_; -} - -int DeterministicMockTCPClientSocket::Write( - net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { - DCHECK(buf); - DCHECK_GT(buf_len, 0); - - if (!connected_) - return net::ERR_UNEXPECTED; - - std::string data(buf->data(), buf_len); - net::MockWriteResult write_result = data_->OnWrite(data); - - if (write_result.async) { - write_callback_ = callback; - write_result_ = write_result.result; - DCHECK(write_callback_ != NULL); - write_pending_ = true; - return net::ERR_IO_PENDING; - } - - was_used_to_convey_data_ = true; - write_pending_ = false; - return write_result.result; -} - -int DeterministicMockTCPClientSocket::Read( - net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { - if (!connected_) - return net::ERR_UNEXPECTED; - - read_data_ = data_->GetNextRead(); - // The buffer should always be big enough to contain all the MockRead data. To - // use small buffers, split the data into multiple MockReads. - DCHECK_LE(read_data_.data_len, buf_len); - - read_buf_ = buf; - read_buf_len_ = buf_len; - read_callback_ = callback; - - if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) { - read_pending_ = true; - DCHECK(read_callback_); - return ERR_IO_PENDING; - } - - was_used_to_convey_data_ = true; - return CompleteRead(); -} - -void DeterministicMockTCPClientSocket::CompleteWrite() { - was_used_to_convey_data_ = true; - write_pending_ = false; - write_callback_->Run(write_result_); -} - -int DeterministicMockTCPClientSocket::CompleteRead() { - DCHECK_GT(read_buf_len_, 0); - DCHECK_LE(read_data_.data_len, read_buf_len_); - DCHECK(read_buf_); - - was_used_to_convey_data_ = true; - - if (read_data_.result == ERR_IO_PENDING) - read_data_ = data_->GetNextRead(); - DCHECK_NE(ERR_IO_PENDING, read_data_.result); - // If read_data_.async is true, we do not need to wait, since this is already - // the callback. Therefore we don't even bother to check it. - int result = read_data_.result; - - if (read_data_.data_len > 0) { - DCHECK(read_data_.data); - result = std::min(read_buf_len_, read_data_.data_len); - memcpy(read_buf_->data(), read_data_.data, result); - } - - if (read_pending_) { - read_pending_ = false; - read_callback_->Run(result); - } - - return result; -} - -class MockSSLClientSocket::ConnectCallback - : public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { - public: - ConnectCallback(MockSSLClientSocket *ssl_client_socket, - net::CompletionCallback* user_callback, - int rv) - : ALLOW_THIS_IN_INITIALIZER_LIST( - net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>( - this, &ConnectCallback::Wrapper)), - ssl_client_socket_(ssl_client_socket), - user_callback_(user_callback), - rv_(rv) { - } - - private: - void Wrapper(int rv) { - if (rv_ == net::OK) - ssl_client_socket_->connected_ = true; - user_callback_->Run(rv_); - delete this; - } - - MockSSLClientSocket* ssl_client_socket_; - net::CompletionCallback* user_callback_; - int rv_; -}; - -MockSSLClientSocket::MockSSLClientSocket( - net::ClientSocketHandle* transport_socket, - const HostPortPair& host_port_pair, - const net::SSLConfig& ssl_config, - SSLHostInfo* ssl_host_info, - net::SSLSocketDataProvider* data) - : MockClientSocket(transport_socket->socket()->NetLog().net_log()), - transport_(transport_socket), - data_(data), - is_npn_state_set_(false), - new_npn_value_(false) { - DCHECK(data_); - delete ssl_host_info; // we take ownership but don't use it. -} - -MockSSLClientSocket::~MockSSLClientSocket() { - Disconnect(); -} - -int MockSSLClientSocket::Connect(net::CompletionCallback* callback) { - ConnectCallback* connect_callback = new ConnectCallback( - this, callback, data_->connect.result); - int rv = transport_->socket()->Connect(connect_callback); - if (rv == net::OK) { - delete connect_callback; - if (data_->connect.result == net::OK) - connected_ = true; - if (data_->connect.async) { - RunCallbackAsync(callback, data_->connect.result); - return net::ERR_IO_PENDING; - } - return data_->connect.result; - } - return rv; -} - -void MockSSLClientSocket::Disconnect() { - MockClientSocket::Disconnect(); - if (transport_->socket() != NULL) - transport_->socket()->Disconnect(); -} - -bool MockSSLClientSocket::IsConnected() const { - return transport_->socket()->IsConnected(); -} - -bool MockSSLClientSocket::WasEverUsed() const { - return transport_->socket()->WasEverUsed(); -} - -bool MockSSLClientSocket::UsingTCPFastOpen() const { - return transport_->socket()->UsingTCPFastOpen(); -} - -int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) { - return transport_->socket()->Read(buf, buf_len, callback); -} - -int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) { - return transport_->socket()->Write(buf, buf_len, callback); -} - -void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { - ssl_info->Reset(); -} - -void MockSSLClientSocket::GetSSLCertRequestInfo( - net::SSLCertRequestInfo* cert_request_info) { - DCHECK(cert_request_info); - if (data_->cert_request_info) { - cert_request_info->host_and_port = - data_->cert_request_info->host_and_port; - cert_request_info->client_certs = data_->cert_request_info->client_certs; - } else { - cert_request_info->Reset(); - } -} - -SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( - std::string* proto) { - *proto = data_->next_proto; - return data_->next_proto_status; -} - -bool MockSSLClientSocket::was_npn_negotiated() const { - if (is_npn_state_set_) - return new_npn_value_; - return data_->was_npn_negotiated; -} - -bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) { - is_npn_state_set_ = true; - return new_npn_value_ = negotiated; -} - StaticSocketDataProvider::StaticSocketDataProvider() : reads_(NULL), read_index_(0), @@ -584,6 +139,26 @@ StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads, StaticSocketDataProvider::~StaticSocketDataProvider() {} +const MockRead& StaticSocketDataProvider::PeekRead() const { + DCHECK(!at_read_eof()); + return reads_[read_index_]; +} + +const MockWrite& StaticSocketDataProvider::PeekWrite() const { + DCHECK(!at_write_eof()); + return writes_[write_index_]; +} + +const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const { + DCHECK_LT(index, read_count_); + return reads_[index]; +} + +const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const { + DCHECK_LT(index, write_count_); + return writes_[index]; +} + MockRead StaticSocketDataProvider::GetNextRead() { DCHECK(!at_read_eof()); reads_[read_index_].time_stamp = base::Time::Now(); @@ -622,26 +197,6 @@ MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { return MockWriteResult(w->async, result); } -const MockRead& StaticSocketDataProvider::PeekRead() const { - DCHECK(!at_read_eof()); - return reads_[read_index_]; -} - -const MockWrite& StaticSocketDataProvider::PeekWrite() const { - DCHECK(!at_write_eof()); - return writes_[write_index_]; -} - -const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const { - DCHECK_LT(index, read_count_); - return reads_[index]; -} - -const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const { - DCHECK_LT(index, write_count_); - return writes_[index]; -} - void StaticSocketDataProvider::Reset() { read_index_ = 0; write_index_ = 0; @@ -702,6 +257,11 @@ DelayedSocketData::DelayedSocketData( DelayedSocketData::~DelayedSocketData() { } +void DelayedSocketData::ForceNextRead() { + write_delay_ = 0; + CompleteRead(); +} + MockRead DelayedSocketData::GetNextRead() { if (write_delay_ > 0) return MockRead(true, ERR_IO_PENDING); @@ -728,11 +288,6 @@ void DelayedSocketData::CompleteRead() { socket()->OnReadComplete(GetNextRead()); } -void DelayedSocketData::ForceNextRead() { - write_delay_ = 0; - CompleteRead(); -} - OrderedSocketData::OrderedSocketData( MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count) : StaticSocketDataProvider(reads, reads_count, writes, writes_count), @@ -750,6 +305,29 @@ OrderedSocketData::OrderedSocketData( set_connect_data(connect); } +void OrderedSocketData::EndLoop() { + // If we've already stopped the loop, don't do it again until we've advanced + // to the next sequence_number. + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()"; + if (loop_stop_stage_ > 0) { + const MockRead& next_read = StaticSocketDataProvider::PeekRead(); + if ((next_read.sequence_number & ~MockRead::STOPLOOP) > + loop_stop_stage_) { + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Clearing stop index"; + loop_stop_stage_ = 0; + } else { + return; + } + } + // Record the sequence_number at which we stopped the loop. + NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ + << ": Posting Quit at read " << read_index(); + loop_stop_stage_ = sequence_number_; + if (callback_) + callback_->RunWithParams(Tuple1<int>(ERR_IO_PENDING)); +} + MockRead OrderedSocketData::GetNextRead() { factory_.RevokeAll(); blocked_ = false; @@ -799,29 +377,6 @@ void OrderedSocketData::Reset() { StaticSocketDataProvider::Reset(); } -void OrderedSocketData::EndLoop() { - // If we've already stopped the loop, don't do it again until we've advanced - // to the next sequence_number. - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()"; - if (loop_stop_stage_ > 0) { - const MockRead& next_read = StaticSocketDataProvider::PeekRead(); - if ((next_read.sequence_number & ~MockRead::STOPLOOP) > - loop_stop_stage_) { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": Clearing stop index"; - loop_stop_stage_ = 0; - } else { - return; - } - } - // Record the sequence_number at which we stopped the loop. - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": Posting Quit at read " << read_index(); - loop_stop_stage_ = sequence_number_; - if (callback_) - callback_->RunWithParams(Tuple1<int>(ERR_IO_PENDING)); -} - void OrderedSocketData::CompleteRead() { if (socket()) { NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_; @@ -841,6 +396,49 @@ DeterministicSocketData::DeterministicSocketData(MockRead* reads, stopped_(false), print_debug_(false) {} +DeterministicSocketData::~DeterministicSocketData() {} + +void DeterministicSocketData::Run() { + SetStopped(false); + int counter = 0; + // Continue to consume data until all data has run out, or the stopped_ flag + // has been set. Consuming data requires two separate operations -- running + // the tasks in the message loop, and explicitly invoking the read/write + // callbacks (simulating network I/O). We check our conditions between each, + // since they can change in either. + while ((!at_write_eof() || !at_read_eof()) && !stopped()) { + if (counter % 2 == 0) + MessageLoop::current()->RunAllPending(); + if (counter % 2 == 1) { + InvokeCallbacks(); + } + counter++; + } + // We're done consuming new data, but it is possible there are still some + // pending callbacks which we expect to complete before returning. + while (socket_ && (socket_->write_pending() || socket_->read_pending()) && + !stopped()) { + InvokeCallbacks(); + MessageLoop::current()->RunAllPending(); + } + SetStopped(false); +} + +void DeterministicSocketData::RunFor(int steps) { + StopAfter(steps); + Run(); +} + +void DeterministicSocketData::SetStop(int seq) { + DCHECK_LT(sequence_number_, seq); + stopping_sequence_number_ = seq; + stopped_ = false; +} + +void DeterministicSocketData::StopAfter(int seq) { + SetStop(sequence_number_ + seq); +} + MockRead DeterministicSocketData::GetNextRead() { current_read_ = StaticSocketDataProvider::PeekRead(); EXPECT_LE(sequence_number_, current_read_.sequence_number); @@ -926,37 +524,6 @@ void DeterministicSocketData::Reset() { NOTREACHED(); } -void DeterministicSocketData::RunFor(int steps) { - StopAfter(steps); - Run(); -} - -void DeterministicSocketData::Run() { - SetStopped(false); - int counter = 0; - // Continue to consume data until all data has run out, or the stopped_ flag - // has been set. Consuming data requires two separate operations -- running - // the tasks in the message loop, and explicitly invoking the read/write - // callbacks (simulating network I/O). We check our conditions between each, - // since they can change in either. - while ((!at_write_eof() || !at_read_eof()) && !stopped()) { - if (counter % 2 == 0) - MessageLoop::current()->RunAllPending(); - if (counter % 2 == 1) { - InvokeCallbacks(); - } - counter++; - } - // We're done consuming new data, but it is possible there are still some - // pending callbacks which we expect to complete before returning. - while (socket_ && (socket_->write_pending() || socket_->read_pending()) && - !stopped()) { - InvokeCallbacks(); - MessageLoop::current()->RunAllPending(); - } - SetStopped(false); -} - void DeterministicSocketData::InvokeCallbacks() { if (socket_ && socket_->write_pending() && (current_write().sequence_number == sequence_number())) { @@ -980,7 +547,6 @@ void DeterministicSocketData::NextStep() { SetStopped(true); } - MockClientSocketFactory::MockClientSocketFactory() {} MockClientSocketFactory::~MockClientSocketFactory() {} @@ -1038,55 +604,493 @@ SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( return socket; } -DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} +MockClientSocket::MockClientSocket(net::NetLog* net_log) + : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), + connected_(false), + net_log_(NetLog::Source(), net_log) { +} -DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {} +bool MockClientSocket::SetReceiveBufferSize(int32 size) { + return true; +} -void DeterministicMockClientSocketFactory::AddSocketDataProvider( - DeterministicSocketData* data) { - mock_data_.Add(data); +bool MockClientSocket::SetSendBufferSize(int32 size) { + return true; } -void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider( - SSLSocketDataProvider* data) { - mock_ssl_data_.Add(data); +void MockClientSocket::Disconnect() { + connected_ = false; } -void DeterministicMockClientSocketFactory::ResetNextMockIndexes() { - mock_data_.ResetNextIndex(); - mock_ssl_data_.ResetNextIndex(); +bool MockClientSocket::IsConnected() const { + return connected_; } -MockSSLClientSocket* DeterministicMockClientSocketFactory:: - GetMockSSLClientSocket(size_t index) const { - DCHECK_LT(index, ssl_client_sockets_.size()); - return ssl_client_sockets_[index]; +bool MockClientSocket::IsConnectedAndIdle() const { + return connected_; } -ClientSocket* DeterministicMockClientSocketFactory::CreateTCPClientSocket( - const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source) { - DeterministicSocketData* data_provider = mock_data().GetNext(); - DeterministicMockTCPClientSocket* socket = - new DeterministicMockTCPClientSocket(net_log, data_provider); - data_provider->set_socket(socket->AsWeakPtr()); - tcp_client_sockets().push_back(socket); - return socket; +int MockClientSocket::GetPeerAddress(AddressList* address) const { + return net::SystemHostResolverProc("localhost", ADDRESS_FAMILY_UNSPECIFIED, + 0, address, NULL); } -SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( - ClientSocketHandle* transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, +const BoundNetLog& MockClientSocket::NetLog() const { + return net_log_; +} + +void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + NOTREACHED(); +} + +void MockClientSocket::GetSSLCertRequestInfo( + net::SSLCertRequestInfo* cert_request_info) { +} + +SSLClientSocket::NextProtoStatus +MockClientSocket::GetNextProto(std::string* proto) { + proto->clear(); + return SSLClientSocket::kNextProtoUnsupported; +} + +MockClientSocket::~MockClientSocket() {} + +void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback, + int result) { + MessageLoop::current()->PostTask(FROM_HERE, + method_factory_.NewRunnableMethod( + &MockClientSocket::RunCallback, callback, result)); +} + +void MockClientSocket::RunCallback(net::CompletionCallback* callback, + int result) { + if (callback) + callback->Run(result); +} + +MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, + net::NetLog* net_log, + net::SocketDataProvider* data) + : MockClientSocket(net_log), + addresses_(addresses), + data_(data), + read_offset_(0), + read_data_(false, net::ERR_UNEXPECTED), + need_read_data_(true), + peer_closed_connection_(false), + pending_buf_(NULL), + pending_buf_len_(0), + pending_callback_(NULL), + was_used_to_convey_data_(false) { + DCHECK(data_); + data_->Reset(); +} + +int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) { + if (!connected_) + return net::ERR_UNEXPECTED; + + // If the buffer is already in use, a read is already in progress! + DCHECK(pending_buf_ == NULL); + + // Store our async IO data. + pending_buf_ = buf; + pending_buf_len_ = buf_len; + pending_callback_ = callback; + + if (need_read_data_) { + read_data_ = data_->GetNextRead(); + if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { + // This MockRead is just a marker to instruct us to set + // peer_closed_connection_. Skip it and get the next one. + read_data_ = data_->GetNextRead(); + peer_closed_connection_ = true; + } + // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility + // to complete the async IO manually later (via OnReadComplete). + if (read_data_.result == ERR_IO_PENDING) { + DCHECK(callback); // We need to be using async IO in this case. + return ERR_IO_PENDING; + } + need_read_data_ = false; + } + + return CompleteRead(); +} + +int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + was_used_to_convey_data_ = true; + + if (write_result.async) { + RunCallbackAsync(callback, write_result.result); + return net::ERR_IO_PENDING; + } + + return write_result.result; +} + +int MockTCPClientSocket::Connect(net::CompletionCallback* callback) { + if (connected_) + return net::OK; + connected_ = true; + peer_closed_connection_ = false; + if (data_->connect_data().async) { + RunCallbackAsync(callback, data_->connect_data().result); + return net::ERR_IO_PENDING; + } + return data_->connect_data().result; +} + +void MockTCPClientSocket::Disconnect() { + MockClientSocket::Disconnect(); + pending_callback_ = NULL; +} + +bool MockTCPClientSocket::IsConnected() const { + return connected_ && !peer_closed_connection_; +} + +bool MockTCPClientSocket::IsConnectedAndIdle() const { + return IsConnected(); +} + +bool MockTCPClientSocket::WasEverUsed() const { + return was_used_to_convey_data_; +} + +bool MockTCPClientSocket::UsingTCPFastOpen() const { + return false; +} + +void MockTCPClientSocket::OnReadComplete(const MockRead& data) { + // There must be a read pending. + DCHECK(pending_buf_); + // You can't complete a read with another ERR_IO_PENDING status code. + DCHECK_NE(ERR_IO_PENDING, data.result); + // Since we've been waiting for data, need_read_data_ should be true. + DCHECK(need_read_data_); + + read_data_ = data; + need_read_data_ = false; + + // The caller is simulating that this IO completes right now. Don't + // let CompleteRead() schedule a callback. + read_data_.async = false; + + net::CompletionCallback* callback = pending_callback_; + int rv = CompleteRead(); + RunCallback(callback, rv); +} + +int MockTCPClientSocket::CompleteRead() { + DCHECK(pending_buf_); + DCHECK(pending_buf_len_ > 0); + + was_used_to_convey_data_ = true; + + // Save the pending async IO data and reset our |pending_| state. + net::IOBuffer* buf = pending_buf_; + int buf_len = pending_buf_len_; + net::CompletionCallback* callback = pending_callback_; + pending_buf_ = NULL; + pending_buf_len_ = 0; + pending_callback_ = NULL; + + int result = read_data_.result; + DCHECK(result != ERR_IO_PENDING); + + if (read_data_.data) { + if (read_data_.data_len - read_offset_ > 0) { + result = std::min(buf_len, read_data_.data_len - read_offset_); + memcpy(buf->data(), read_data_.data + read_offset_, result); + read_offset_ += result; + if (read_offset_ == read_data_.data_len) { + need_read_data_ = true; + read_offset_ = 0; + } + } else { + result = 0; // EOF + } + } + + if (read_data_.async) { + DCHECK(callback); + RunCallbackAsync(callback, result); + return net::ERR_IO_PENDING; + } + return result; +} + +DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( + net::NetLog* net_log, net::DeterministicSocketData* data) + : MockClientSocket(net_log), + write_pending_(false), + write_callback_(NULL), + write_result_(0), + read_data_(), + read_buf_(NULL), + read_buf_len_(0), + read_pending_(false), + read_callback_(NULL), + data_(data), + was_used_to_convey_data_(false) {} + +DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {} + +void DeterministicMockTCPClientSocket::CompleteWrite() { + was_used_to_convey_data_ = true; + write_pending_ = false; + write_callback_->Run(write_result_); +} + +int DeterministicMockTCPClientSocket::CompleteRead() { + DCHECK_GT(read_buf_len_, 0); + DCHECK_LE(read_data_.data_len, read_buf_len_); + DCHECK(read_buf_); + + was_used_to_convey_data_ = true; + + if (read_data_.result == ERR_IO_PENDING) + read_data_ = data_->GetNextRead(); + DCHECK_NE(ERR_IO_PENDING, read_data_.result); + // If read_data_.async is true, we do not need to wait, since this is already + // the callback. Therefore we don't even bother to check it. + int result = read_data_.result; + + if (read_data_.data_len > 0) { + DCHECK(read_data_.data); + result = std::min(read_buf_len_, read_data_.data_len); + memcpy(read_buf_->data(), read_data_.data, result); + } + + if (read_pending_) { + read_pending_ = false; + read_callback_->Run(result); + } + + return result; +} + +int DeterministicMockTCPClientSocket::Write( + net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + + if (!connected_) + return net::ERR_UNEXPECTED; + + std::string data(buf->data(), buf_len); + net::MockWriteResult write_result = data_->OnWrite(data); + + if (write_result.async) { + write_callback_ = callback; + write_result_ = write_result.result; + DCHECK(write_callback_ != NULL); + write_pending_ = true; + return net::ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + write_pending_ = false; + return write_result.result; +} + +int DeterministicMockTCPClientSocket::Read( + net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback) { + if (!connected_) + return net::ERR_UNEXPECTED; + + read_data_ = data_->GetNextRead(); + // The buffer should always be big enough to contain all the MockRead data. To + // use small buffers, split the data into multiple MockReads. + DCHECK_LE(read_data_.data_len, buf_len); + + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; + + if (read_data_.async || (read_data_.result == ERR_IO_PENDING)) { + read_pending_ = true; + DCHECK(read_callback_); + return ERR_IO_PENDING; + } + + was_used_to_convey_data_ = true; + return CompleteRead(); +} + +// TODO(erikchen): Support connect sequencing. +int DeterministicMockTCPClientSocket::Connect( + net::CompletionCallback* callback) { + if (connected_) + return net::OK; + connected_ = true; + if (data_->connect_data().async) { + RunCallbackAsync(callback, data_->connect_data().result); + return net::ERR_IO_PENDING; + } + return data_->connect_data().result; +} + +void DeterministicMockTCPClientSocket::Disconnect() { + MockClientSocket::Disconnect(); +} + +bool DeterministicMockTCPClientSocket::IsConnected() const { + return connected_; +} + +bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const { + return IsConnected(); +} + +bool DeterministicMockTCPClientSocket::WasEverUsed() const { + return was_used_to_convey_data_; +} + +bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const { + return false; +} + +void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} + +class MockSSLClientSocket::ConnectCallback + : public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { + public: + ConnectCallback(MockSSLClientSocket *ssl_client_socket, + net::CompletionCallback* user_callback, + int rv) + : ALLOW_THIS_IN_INITIALIZER_LIST( + net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>( + this, &ConnectCallback::Wrapper)), + ssl_client_socket_(ssl_client_socket), + user_callback_(user_callback), + rv_(rv) { + } + + private: + void Wrapper(int rv) { + if (rv_ == net::OK) + ssl_client_socket_->connected_ = true; + user_callback_->Run(rv_); + delete this; + } + + MockSSLClientSocket* ssl_client_socket_; + net::CompletionCallback* user_callback_; + int rv_; +}; + +MockSSLClientSocket::MockSSLClientSocket( + net::ClientSocketHandle* transport_socket, + const HostPortPair& host_port_pair, + const net::SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, - CertVerifier* cert_verifier, - DnsCertProvenanceChecker* dns_cert_checker) { - MockSSLClientSocket* socket = - new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, - ssl_host_info, mock_ssl_data_.GetNext()); - ssl_client_sockets_.push_back(socket); - return socket; + net::SSLSocketDataProvider* data) + : MockClientSocket(transport_socket->socket()->NetLog().net_log()), + transport_(transport_socket), + data_(data), + is_npn_state_set_(false), + new_npn_value_(false) { + DCHECK(data_); + delete ssl_host_info; // we take ownership but don't use it. +} + +MockSSLClientSocket::~MockSSLClientSocket() { + Disconnect(); +} + +int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) { + return transport_->socket()->Read(buf, buf_len, callback); +} + +int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) { + return transport_->socket()->Write(buf, buf_len, callback); +} + +int MockSSLClientSocket::Connect(net::CompletionCallback* callback) { + ConnectCallback* connect_callback = new ConnectCallback( + this, callback, data_->connect.result); + int rv = transport_->socket()->Connect(connect_callback); + if (rv == net::OK) { + delete connect_callback; + if (data_->connect.result == net::OK) + connected_ = true; + if (data_->connect.async) { + RunCallbackAsync(callback, data_->connect.result); + return net::ERR_IO_PENDING; + } + return data_->connect.result; + } + return rv; +} + +void MockSSLClientSocket::Disconnect() { + MockClientSocket::Disconnect(); + if (transport_->socket() != NULL) + transport_->socket()->Disconnect(); +} + +bool MockSSLClientSocket::IsConnected() const { + return transport_->socket()->IsConnected(); +} + +bool MockSSLClientSocket::WasEverUsed() const { + return transport_->socket()->WasEverUsed(); +} + +bool MockSSLClientSocket::UsingTCPFastOpen() const { + return transport_->socket()->UsingTCPFastOpen(); +} + +void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { + ssl_info->Reset(); +} + +void MockSSLClientSocket::GetSSLCertRequestInfo( + net::SSLCertRequestInfo* cert_request_info) { + DCHECK(cert_request_info); + if (data_->cert_request_info) { + cert_request_info->host_and_port = + data_->cert_request_info->host_and_port; + cert_request_info->client_certs = data_->cert_request_info->client_certs; + } else { + cert_request_info->Reset(); + } +} + +SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( + std::string* proto) { + *proto = data_->next_proto; + return data_->next_proto_status; +} + +bool MockSSLClientSocket::was_npn_negotiated() const { + if (is_npn_state_set_) + return new_npn_value_; + return data_->was_npn_negotiated; +} + +bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) { + is_npn_state_set_ = true; + return new_npn_value_ = negotiated; +} + +void MockSSLClientSocket::OnReadComplete(const MockRead& data) { + NOTIMPLEMENTED(); } TestSocketRequest::TestSocketRequest( @@ -1215,6 +1219,8 @@ MockTCPClientSocketPool::MockTCPClientSocketPool( cancel_count_(0) { } +MockTCPClientSocketPool::~MockTCPClientSocketPool() {} + int MockTCPClientSocketPool::RequestSocket(const std::string& group_name, const void* socket_params, RequestPriority priority, @@ -1247,7 +1253,56 @@ void MockTCPClientSocketPool::ReleaseSocket(const std::string& group_name, delete socket; } -MockTCPClientSocketPool::~MockTCPClientSocketPool() {} +DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} + +DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {} + +void DeterministicMockClientSocketFactory::AddSocketDataProvider( + DeterministicSocketData* data) { + mock_data_.Add(data); +} + +void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider( + SSLSocketDataProvider* data) { + mock_ssl_data_.Add(data); +} + +void DeterministicMockClientSocketFactory::ResetNextMockIndexes() { + mock_data_.ResetNextIndex(); + mock_ssl_data_.ResetNextIndex(); +} + +MockSSLClientSocket* DeterministicMockClientSocketFactory:: + GetMockSSLClientSocket(size_t index) const { + DCHECK_LT(index, ssl_client_sockets_.size()); + return ssl_client_sockets_[index]; +} + +ClientSocket* DeterministicMockClientSocketFactory::CreateTCPClientSocket( + const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source) { + DeterministicSocketData* data_provider = mock_data().GetNext(); + DeterministicMockTCPClientSocket* socket = + new DeterministicMockTCPClientSocket(net_log, data_provider); + data_provider->set_socket(socket->AsWeakPtr()); + tcp_client_sockets().push_back(socket); + return socket; +} + +SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( + ClientSocketHandle* transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + SSLHostInfo* ssl_host_info, + CertVerifier* cert_verifier, + DnsCertProvenanceChecker* dns_cert_checker) { + MockSSLClientSocket* socket = + new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, + ssl_host_info, mock_ssl_data_.GetNext()); + ssl_client_sockets_.push_back(socket); + return socket; +} MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( int max_sockets, @@ -1259,6 +1314,8 @@ MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( tcp_pool_(tcp_pool) { } +MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {} + int MockSOCKSClientSocketPool::RequestSocket(const std::string& group_name, const void* socket_params, RequestPriority priority, @@ -1280,8 +1337,6 @@ void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, return tcp_pool_->ReleaseSocket(group_name, socket, id); } -MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {} - const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); diff --git a/net/socket/socket_test_util.h b/net/socket/socket_test_util.h index 1e09708..4a15f37 100644 --- a/net/socket/socket_test_util.h +++ b/net/socket/socket_test_util.h @@ -182,11 +182,12 @@ class StaticSocketDataProvider : public SocketDataProvider { bool at_read_eof() const { return read_index_ >= read_count_; } bool at_write_eof() const { return write_index_ >= write_count_; } + virtual void CompleteRead() {} + // SocketDataProvider methods: virtual MockRead GetNextRead(); virtual MockWriteResult OnWrite(const std::string& data); virtual void Reset(); - virtual void CompleteRead() {} private: MockRead* reads_; @@ -208,16 +209,16 @@ class DynamicSocketDataProvider : public SocketDataProvider { DynamicSocketDataProvider(); virtual ~DynamicSocketDataProvider(); - // SocketDataProvider methods: - virtual MockRead GetNextRead(); - virtual MockWriteResult OnWrite(const std::string& data) = 0; - virtual void Reset(); - int short_read_limit() const { return short_read_limit_; } void set_short_read_limit(int limit) { short_read_limit_ = limit; } void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } + // SocketDataProvider methods: + virtual MockRead GetNextRead(); + virtual MockWriteResult OnWrite(const std::string& data) = 0; + virtual void Reset(); + protected: // The next time there is a read from this socket, it will return |data|. // Before calling SimulateRead next time, the previous data must be consumed. @@ -284,11 +285,13 @@ class DelayedSocketData : public StaticSocketDataProvider, MockWrite* writes, size_t writes_count); ~DelayedSocketData(); + void ForceNextRead(); + + // StaticSocketDataProvider: virtual MockRead GetNextRead(); virtual MockWriteResult OnWrite(const std::string& data); virtual void Reset(); virtual void CompleteRead(); - void ForceNextRead(); private: int write_delay_; @@ -327,11 +330,6 @@ class OrderedSocketData : public StaticSocketDataProvider, MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count); - virtual MockRead GetNextRead(); - virtual MockWriteResult OnWrite(const std::string& data); - virtual void Reset(); - virtual void CompleteRead(); - void SetCompletionCallback(CompletionCallback* callback) { callback_ = callback; } @@ -339,6 +337,12 @@ class OrderedSocketData : public StaticSocketDataProvider, // Posts a quit message to the current message loop, if one is running. void EndLoop(); + // StaticSocketDataProvider: + virtual MockRead GetNextRead(); + virtual MockWriteResult OnWrite(const std::string& data); + virtual void Reset(); + virtual void CompleteRead(); + private: friend class base::RefCounted<OrderedSocketData>; virtual ~OrderedSocketData(); @@ -416,19 +420,7 @@ class DeterministicSocketData : public StaticSocketDataProvider, // |writes| the list of MockWrite completions. DeterministicSocketData(MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count); - - // When the socket calls Read(), that calls GetNextRead(), and expects either - // ERR_IO_PENDING or data. - virtual MockRead GetNextRead(); - - // When the socket calls Write(), it always completes synchronously. OnWrite() - // checks to make sure the written data matches the expected data. The - // callback will not be invoked until its sequence number is reached. - virtual MockWriteResult OnWrite(const std::string& data); - - virtual void Reset(); - - virtual void CompleteRead() {} + virtual ~DeterministicSocketData(); // Consume all the data up to the give stop point (via SetStop()). void Run(); @@ -437,16 +429,10 @@ class DeterministicSocketData : public StaticSocketDataProvider, void RunFor(int steps); // Stop at step |seq|, which must be in the future. - virtual void SetStop(int seq) { - DCHECK_LT(sequence_number_, seq); - stopping_sequence_number_ = seq; - stopped_ = false; - } + virtual void SetStop(int seq); // Stop |seq| steps after the current step. - virtual void StopAfter(int seq) { - SetStop(sequence_number_ + seq); - } + virtual void StopAfter(int seq); bool stopped() const { return stopped_; } void SetStopped(bool val) { stopped_ = val; } MockRead& current_read() { return current_read_; } @@ -456,6 +442,19 @@ class DeterministicSocketData : public StaticSocketDataProvider, socket_ = socket; } + // StaticSocketDataProvider: + + // When the socket calls Read(), that calls GetNextRead(), and expects either + // ERR_IO_PENDING or data. + virtual MockRead GetNextRead(); + + // When the socket calls Write(), it always completes synchronously. OnWrite() + // checks to make sure the written data matches the expected data. The + // callback will not be invoked until its sequence number is reached. + virtual MockWriteResult OnWrite(const std::string& data); + virtual void Reset(); + virtual void CompleteRead() {} + private: // Invoke the read and write callbacks, if the timing is appropriate. void InvokeCallbacks(); @@ -471,7 +470,6 @@ class DeterministicSocketData : public StaticSocketDataProvider, bool print_debug_; }; - // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}ClientSocket // objects get instantiated, they take their data from the i'th element of this // array. @@ -529,6 +527,13 @@ class MockClientSocketFactory : public ClientSocketFactory { // created. MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const; + SocketDataProviderArray<SocketDataProvider>& mock_data() { + return mock_data_; + } + std::vector<MockTCPClientSocket*>& tcp_client_sockets() { + return tcp_client_sockets_; + } + // ClientSocketFactory virtual ClientSocket* CreateTCPClientSocket( const AddressList& addresses, @@ -541,12 +546,6 @@ class MockClientSocketFactory : public ClientSocketFactory { SSLHostInfo* ssl_host_info, CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker); - SocketDataProviderArray<SocketDataProvider>& mock_data() { - return mock_data_; - } - std::vector<MockTCPClientSocket*>& tcp_client_sockets() { - return tcp_client_sockets_; - } private: SocketDataProviderArray<SocketDataProvider> mock_data_; @@ -560,13 +559,29 @@ class MockClientSocketFactory : public ClientSocketFactory { class MockClientSocket : public net::SSLClientSocket { public: explicit MockClientSocket(net::NetLog* net_log); + + // If an async IO is pending because the SocketDataProvider returned + // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete + // is called to complete the asynchronous read operation. + // data.async is ignored, and this read is completed synchronously as + // part of this call. + virtual void OnReadComplete(const MockRead& data) = 0; + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback) = 0; + virtual bool SetReceiveBufferSize(int32 size); + virtual bool SetSendBufferSize(int32 size); + // ClientSocket methods: virtual int Connect(net::CompletionCallback* callback) = 0; virtual void Disconnect(); virtual bool IsConnected() const; virtual bool IsConnectedAndIdle() const; virtual int GetPeerAddress(AddressList* address) const; - virtual const BoundNetLog& NetLog() const { return net_log_;} + virtual const BoundNetLog& NetLog() const; virtual void SetSubresourceSpeculation() {} virtual void SetOmniboxSpeculation() {} @@ -576,23 +591,8 @@ class MockClientSocket : public net::SSLClientSocket { net::SSLCertRequestInfo* cert_request_info); virtual NextProtoStatus GetNextProto(std::string* proto); - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback) = 0; - virtual bool SetReceiveBufferSize(int32 size) { return true; } - virtual bool SetSendBufferSize(int32 size) { return true; } - - // If an async IO is pending because the SocketDataProvider returned - // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete - // is called to complete the asynchronous read operation. - // data.async is ignored, and this read is completed synchronously as - // part of this call. - virtual void OnReadComplete(const MockRead& data) = 0; - protected: - virtual ~MockClientSocket() {} + virtual ~MockClientSocket(); void RunCallbackAsync(net::CompletionCallback* callback, int result); void RunCallback(net::CompletionCallback*, int result); @@ -609,13 +609,7 @@ class MockTCPClientSocket : public MockClientSocket { MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log, net::SocketDataProvider* socket); - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const { return IsConnected(); } - virtual bool WasEverUsed() const { return was_used_to_convey_data_; } - virtual bool UsingTCPFastOpen() const { return false; } + net::AddressList addresses() const { return addresses_; } // Socket methods: virtual int Read(net::IOBuffer* buf, int buf_len, @@ -623,9 +617,16 @@ class MockTCPClientSocket : public MockClientSocket { virtual int Write(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback); - virtual void OnReadComplete(const MockRead& data); + // ClientSocket methods: + virtual int Connect(net::CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + virtual bool WasEverUsed() const; + virtual bool UsingTCPFastOpen() const; - net::AddressList addresses() const { return addresses_; } + // MockClientSocket: + virtual void OnReadComplete(const MockRead& data); private: int CompleteRead(); @@ -654,27 +655,30 @@ class DeterministicMockTCPClientSocket : public MockClientSocket, public: DeterministicMockTCPClientSocket(net::NetLog* net_log, net::DeterministicSocketData* data); + virtual ~DeterministicMockTCPClientSocket(); - // ClientSocket methods: - virtual int Connect(net::CompletionCallback* callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const { return IsConnected(); } - virtual bool WasEverUsed() const { return was_used_to_convey_data_; } - virtual bool UsingTCPFastOpen() const { return false; } + bool write_pending() const { return write_pending_; } + bool read_pending() const { return read_pending_; } - // Socket methods: + void CompleteWrite(); + int CompleteRead(); + + // Socket: virtual int Write(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback); virtual int Read(net::IOBuffer* buf, int buf_len, net::CompletionCallback* callback); - bool write_pending() const { return write_pending_; } - bool read_pending() const { return read_pending_; } + // ClientSocket: + virtual int Connect(net::CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + virtual bool WasEverUsed() const; + virtual bool UsingTCPFastOpen() const; - void CompleteWrite(); - int CompleteRead(); - void OnReadComplete(const MockRead& data); + // MockClientSocket: + virtual void OnReadComplete(const MockRead& data); private: bool write_pending_; @@ -699,7 +703,13 @@ class MockSSLClientSocket : public MockClientSocket { const net::SSLConfig& ssl_config, SSLHostInfo* ssl_host_info, net::SSLSocketDataProvider* socket); - ~MockSSLClientSocket(); + virtual ~MockSSLClientSocket(); + + // Socket methods: + virtual int Read(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); + virtual int Write(net::IOBuffer* buf, int buf_len, + net::CompletionCallback* callback); // ClientSocket methods: virtual int Connect(net::CompletionCallback* callback); @@ -708,12 +718,6 @@ class MockSSLClientSocket : public MockClientSocket { virtual bool WasEverUsed() const; virtual bool UsingTCPFastOpen() const; - // Socket methods: - virtual int Read(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - virtual int Write(net::IOBuffer* buf, int buf_len, - net::CompletionCallback* callback); - // SSLClientSocket methods: virtual void GetSSLInfo(net::SSLInfo* ssl_info); virtual void GetSSLCertRequestInfo( @@ -723,7 +727,7 @@ class MockSSLClientSocket : public MockClientSocket { virtual bool set_was_npn_negotiated(bool negotiated); // This MockSocket does not implement the manual async IO feature. - virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } + virtual void OnReadComplete(const MockRead& data); private: class ConnectCallback; @@ -878,6 +882,13 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { // created. MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const; + SocketDataProviderArray<DeterministicSocketData>& mock_data() { + return mock_data_; + } + std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() { + return tcp_client_sockets_; + } + // ClientSocketFactory virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses, NetLog* net_log, @@ -890,13 +901,6 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { CertVerifier* cert_verifier, DnsCertProvenanceChecker* dns_cert_checker); - SocketDataProviderArray<DeterministicSocketData>& mock_data() { - return mock_data_; - } - std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() { - return tcp_client_sockets_; - } - private: SocketDataProviderArray<DeterministicSocketData> mock_data_; SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; diff --git a/net/socket/socks_client_socket_pool.cc b/net/socket/socks_client_socket_pool.cc index 4db8c5b..42abb7b 100644 --- a/net/socket/socks_client_socket_pool.cc +++ b/net/socket/socks_client_socket_pool.cc @@ -75,11 +75,6 @@ LoadState SOCKSConnectJob::GetLoadState() const { } } -int SOCKSConnectJob::ConnectInternal() { - next_state_ = STATE_TCP_CONNECT; - return DoLoop(OK); -} - void SOCKSConnectJob::OnIOComplete(int result) { int rv = DoLoop(result); if (rv != ERR_IO_PENDING) @@ -163,6 +158,11 @@ int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { return result; } +int SOCKSConnectJob::ConnectInternal() { + next_state_ = STATE_TCP_CONNECT; + return DoLoop(OK); +} + ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, diff --git a/net/socket/socks_client_socket_pool.h b/net/socket/socks_client_socket_pool.h index 3ba71546..5608c20 100644 --- a/net/socket/socks_client_socket_pool.h +++ b/net/socket/socks_client_socket_pool.h @@ -75,11 +75,6 @@ class SOCKSConnectJob : public ConnectJob { STATE_NONE, }; - // Begins the tcp connection and the SOCKS handshake. Returns OK on success - // and ERR_IO_PENDING if it cannot immediately service the request. - // Otherwise, it returns a net error code. - virtual int ConnectInternal(); - void OnIOComplete(int result); // Runs the state transition loop. @@ -90,6 +85,11 @@ class SOCKSConnectJob : public ConnectJob { int DoSOCKSConnect(); int DoSOCKSConnectComplete(int result); + // Begins the tcp connection and the SOCKS handshake. Returns OK on success + // and ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal(); + scoped_refptr<SOCKSSocketParams> socks_params_; TCPClientSocketPool* const tcp_pool_; HostResolver* const resolver_; diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc index b8ffca8..ff96212 100644 --- a/net/socket/ssl_client_socket_pool.cc +++ b/net/socket/ssl_client_socket_pool.cc @@ -121,24 +121,16 @@ LoadState SSLConnectJob::GetLoadState() const { } } -int SSLConnectJob::ConnectInternal() { - switch (params_->proxy()) { - case ProxyServer::SCHEME_DIRECT: - next_state_ = STATE_TCP_CONNECT; - break; - case ProxyServer::SCHEME_HTTP: - case ProxyServer::SCHEME_HTTPS: - next_state_ = STATE_TUNNEL_CONNECT; - break; - case ProxyServer::SCHEME_SOCKS4: - case ProxyServer::SCHEME_SOCKS5: - next_state_ = STATE_SOCKS_CONNECT; - break; - default: - NOTREACHED() << "unknown proxy type"; - break; +void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle * handle) { + // Headers in |error_response_info_| indicate a proxy tunnel setup + // problem. See DoTunnelConnectComplete. + if (error_response_info_.headers) { + handle->set_pending_http_proxy_connection( + transport_socket_handle_.release()); } - return DoLoop(OK); + handle->set_ssl_error_response_info(error_response_info_); + if (!ssl_connect_start_time_.is_null()) + handle->set_is_ssl_error(true); } void SSLConnectJob::OnIOComplete(int result) { @@ -276,18 +268,6 @@ int SSLConnectJob::DoTunnelConnectComplete(int result) { return result; } -void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle * handle) { - // Headers in |error_response_info_| indicate a proxy tunnel setup - // problem. See DoTunnelConnectComplete. - if (error_response_info_.headers) { - handle->set_pending_http_proxy_connection( - transport_socket_handle_.release()); - } - handle->set_ssl_error_response_info(error_response_info_); - if (!ssl_connect_start_time_.is_null()) - handle->set_is_ssl_error(true); -} - int SSLConnectJob::DoSSLConnect() { next_state_ = STATE_SSL_CONNECT_COMPLETE; // Reset the timeout to just the time allowed for the SSL handshake. @@ -361,15 +341,24 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { return result; } -ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( - const std::string& group_name, - const PoolBase::Request& request, - ConnectJob::Delegate* delegate) const { - return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), - tcp_pool_, socks_pool_, http_proxy_pool_, - client_socket_factory_, host_resolver_, - cert_verifier_, dnsrr_resolver_, dns_cert_checker_, - ssl_host_info_factory_, delegate, net_log_); +int SSLConnectJob::ConnectInternal() { + switch (params_->proxy()) { + case ProxyServer::SCHEME_DIRECT: + next_state_ = STATE_TCP_CONNECT; + break; + case ProxyServer::SCHEME_HTTP: + case ProxyServer::SCHEME_HTTPS: + next_state_ = STATE_TUNNEL_CONNECT; + break; + case ProxyServer::SCHEME_SOCKS4: + case ProxyServer::SCHEME_SOCKS5: + next_state_ = STATE_SOCKS_CONNECT; + break; + default: + NOTREACHED() << "unknown proxy type"; + break; + } + return DoLoop(OK); } SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( @@ -448,6 +437,17 @@ SSLClientSocketPool::~SSLClientSocketPool() { ssl_config_service_->RemoveObserver(this); } +ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( + const std::string& group_name, + const PoolBase::Request& request, + ConnectJob::Delegate* delegate) const { + return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), + tcp_pool_, socks_pool_, http_proxy_pool_, + client_socket_factory_, host_resolver_, + cert_verifier_, dnsrr_resolver_, dns_cert_checker_, + ssl_host_info_factory_, delegate, net_log_); +} + int SSLClientSocketPool::RequestSocket(const std::string& group_name, const void* socket_params, RequestPriority priority, @@ -504,10 +504,6 @@ LoadState SSLClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void SSLClientSocketPool::OnSSLConfigChanged() { - Flush(); -} - DictionaryValue* SSLClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -543,4 +539,8 @@ ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const { return base_.histograms(); } +void SSLClientSocketPool::OnSSLConfigChanged() { + Flush(); +} + } // namespace net diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h index d9d9594..41cf2a7 100644 --- a/net/socket/ssl_client_socket_pool.h +++ b/net/socket/ssl_client_socket_pool.h @@ -122,11 +122,6 @@ class SSLConnectJob : public ConnectJob { STATE_NONE, }; - // Starts the SSL connection process. Returns OK on success and - // ERR_IO_PENDING if it cannot immediately service the request. - // Otherwise, it returns a net error code. - virtual int ConnectInternal(); - void OnIOComplete(int result); // Runs the state transition loop. @@ -141,6 +136,11 @@ class SSLConnectJob : public ConnectJob { int DoSSLConnect(); int DoSSLConnectComplete(int result); + // Starts the SSL connection process. Returns OK on success and + // ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + virtual int ConnectInternal(); + scoped_refptr<SSLSocketParams> params_; TCPClientSocketPool* const tcp_pool_; SOCKSClientSocketPool* const socks_pool_; diff --git a/net/spdy/spdy_session.h b/net/spdy/spdy_session.h index 6268a4f..3afa7c0 100644 --- a/net/spdy/spdy_session.h +++ b/net/spdy/spdy_session.h @@ -200,28 +200,19 @@ class SpdySession : public base::RefCounted<SpdySession>, friend class base::RefCounted<SpdySession>; FRIEND_TEST_ALL_PREFIXES(SpdySessionTest, GetActivePushStream); - enum State { - IDLE, - CONNECTING, - CONNECTED, - CLOSED - }; - - enum { kDefaultMaxConcurrentStreams = 10 }; - struct PendingCreateStream { - const GURL* url; - RequestPriority priority; - scoped_refptr<SpdyStream>* spdy_stream; - const BoundNetLog* stream_net_log; - CompletionCallback* callback; - PendingCreateStream(const GURL& url, RequestPriority priority, scoped_refptr<SpdyStream>* spdy_stream, const BoundNetLog& stream_net_log, CompletionCallback* callback) : url(&url), priority(priority), spdy_stream(spdy_stream), stream_net_log(&stream_net_log), callback(callback) { } + + const GURL* url; + RequestPriority priority; + scoped_refptr<SpdyStream>* spdy_stream; + const BoundNetLog* stream_net_log; + CompletionCallback* callback; }; typedef std::queue<PendingCreateStream, std::list< PendingCreateStream> > PendingCreateStreamQueue; @@ -242,6 +233,15 @@ class SpdySession : public base::RefCounted<SpdySession>, typedef std::map<const scoped_refptr<SpdyStream>*, CallbackResultPair> PendingCallbackMap; + enum State { + IDLE, + CONNECTING, + CONNECTED, + CLOSED + }; + + enum { kDefaultMaxConcurrentStreams = 10 }; + virtual ~SpdySession(); void ProcessPendingCreateStreams(); @@ -251,13 +251,6 @@ class SpdySession : public base::RefCounted<SpdySession>, scoped_refptr<SpdyStream>* spdy_stream, const BoundNetLog& stream_net_log); - // SpdyFramerVisitorInterface - virtual void OnError(spdy::SpdyFramer*); - virtual void OnStreamFrameData(spdy::SpdyStreamId stream_id, - const char* data, - size_t len); - virtual void OnControl(const spdy::SpdyControlFrame* frame); - // Control frame handlers. void OnSyn(const spdy::SpdySynStreamControlFrame& frame, const linked_ptr<spdy::SpdyHeaderBlock>& headers); @@ -325,6 +318,13 @@ class SpdySession : public base::RefCounted<SpdySession>, // can be deferred to the MessageLoop, so we avoid re-entrancy problems. void InvokeUserStreamCreationCallback(scoped_refptr<SpdyStream>* stream); + // SpdyFramerVisitorInterface: + virtual void OnError(spdy::SpdyFramer*); + virtual void OnStreamFrameData(spdy::SpdyStreamId stream_id, + const char* data, + size_t len); + virtual void OnControl(const spdy::SpdyControlFrame* frame); + // Callbacks for the Spdy session. CompletionCallbackImpl<SpdySession> read_callback_; CompletionCallbackImpl<SpdySession> write_callback_; @@ -439,12 +439,12 @@ class NetLogSpdySynParameter : public NetLog::EventParameters { spdy::SpdyStreamId id, spdy::SpdyStreamId associated_stream); - virtual Value* ToValue() const; - const linked_ptr<spdy::SpdyHeaderBlock>& GetHeaders() const { return headers_; } + virtual Value* ToValue() const; + private: virtual ~NetLogSpdySynParameter(); diff --git a/net/spdy/spdy_test_util.cc b/net/spdy/spdy_test_util.cc index 3a7f771..07a5980 100644 --- a/net/spdy/spdy_test_util.cc +++ b/net/spdy/spdy_test_util.cc @@ -852,6 +852,102 @@ int CombineFrames(const spdy::SpdyFrame** frames, int num_frames, return total_len; } +SpdySessionDependencies::SpdySessionDependencies() + : host_resolver(new MockHostResolver), + cert_verifier(new CertVerifier), + proxy_service(ProxyService::CreateDirect()), + ssl_config_service(new SSLConfigServiceDefaults), + socket_factory(new MockClientSocketFactory), + deterministic_socket_factory(new DeterministicMockClientSocketFactory), + http_auth_handler_factory( + HttpAuthHandlerFactory::CreateDefault(host_resolver.get())) { + // Note: The CancelledTransaction test does cleanup by running all + // tasks in the message loop (RunAllPending). Unfortunately, that + // doesn't clean up tasks on the host resolver thread; and + // TCPConnectJob is currently not cancellable. Using synchronous + // lookups allows the test to shutdown cleanly. Until we have + // cancellable TCPConnectJobs, use synchronous lookups. + host_resolver->set_synchronous_mode(true); +} + +SpdySessionDependencies::SpdySessionDependencies(ProxyService* proxy_service) + : host_resolver(new MockHostResolver), + cert_verifier(new CertVerifier), + proxy_service(proxy_service), + ssl_config_service(new SSLConfigServiceDefaults), + socket_factory(new MockClientSocketFactory), + deterministic_socket_factory(new DeterministicMockClientSocketFactory), + http_auth_handler_factory( + HttpAuthHandlerFactory::CreateDefault(host_resolver.get())) {} + +SpdySessionDependencies::~SpdySessionDependencies() {} + +// static +HttpNetworkSession* SpdySessionDependencies::SpdyCreateSession( + SpdySessionDependencies* session_deps) { + return new HttpNetworkSession(session_deps->host_resolver.get(), + session_deps->cert_verifier.get(), + NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, + NULL /* ssl_host_info_factory */, + session_deps->proxy_service, + session_deps->socket_factory.get(), + session_deps->ssl_config_service, + new SpdySessionPool(NULL), + session_deps->http_auth_handler_factory.get(), + NULL, + NULL); +} + +// static +HttpNetworkSession* SpdySessionDependencies::SpdyCreateSessionDeterministic( + SpdySessionDependencies* session_deps) { + return new HttpNetworkSession(session_deps->host_resolver.get(), + session_deps->cert_verifier.get(), + NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, + NULL /* ssl_host_info_factory */, + session_deps->proxy_service, + session_deps-> + deterministic_socket_factory.get(), + session_deps->ssl_config_service, + new SpdySessionPool(NULL), + session_deps->http_auth_handler_factory.get(), + NULL, + NULL); +} + +SpdyURLRequestContext::SpdyURLRequestContext() { + host_resolver_ = new MockHostResolver(); + cert_verifier_ = new CertVerifier; + proxy_service_ = ProxyService::CreateDirect(); + ssl_config_service_ = new SSLConfigServiceDefaults; + http_auth_handler_factory_ = HttpAuthHandlerFactory::CreateDefault( + host_resolver_); + http_transaction_factory_ = new HttpCache( + new HttpNetworkLayer(&socket_factory_, + host_resolver_, + cert_verifier_, + NULL /* dnsrr_resolver */, + NULL /* dns_cert_checker */, + NULL /* ssl_host_info_factory */, + proxy_service_, + ssl_config_service_, + new SpdySessionPool(NULL), + http_auth_handler_factory_, + network_delegate_, + NULL), + NULL /* net_log */, + HttpCache::DefaultBackend::InMemory(0)); +} + +SpdyURLRequestContext::~SpdyURLRequestContext() { + delete http_transaction_factory_; + delete http_auth_handler_factory_; + delete cert_verifier_; + delete host_resolver_; +} + const SpdyHeaderInfo make_spdy_header(spdy::SpdyControlType type) { const SpdyHeaderInfo kHeader = { type, // Kind = Syn diff --git a/net/spdy/spdy_test_util.h b/net/spdy/spdy_test_util.h index 698d511..8839514 100644 --- a/net/spdy/spdy_test_util.h +++ b/net/spdy/spdy_test_util.h @@ -326,34 +326,17 @@ int CombineFrames(const spdy::SpdyFrame** frames, int num_frames, class SpdySessionDependencies { public: // Default set of dependencies -- "null" proxy service. - SpdySessionDependencies() - : host_resolver(new MockHostResolver), - cert_verifier(new CertVerifier), - proxy_service(ProxyService::CreateDirect()), - ssl_config_service(new SSLConfigServiceDefaults), - socket_factory(new MockClientSocketFactory), - deterministic_socket_factory(new DeterministicMockClientSocketFactory), - http_auth_handler_factory( - HttpAuthHandlerFactory::CreateDefault(host_resolver.get())) { - // Note: The CancelledTransaction test does cleanup by running all - // tasks in the message loop (RunAllPending). Unfortunately, that - // doesn't clean up tasks on the host resolver thread; and - // TCPConnectJob is currently not cancellable. Using synchronous - // lookups allows the test to shutdown cleanly. Until we have - // cancellable TCPConnectJobs, use synchronous lookups. - host_resolver->set_synchronous_mode(true); - } + SpdySessionDependencies(); // Custom proxy service dependency. - explicit SpdySessionDependencies(ProxyService* proxy_service) - : host_resolver(new MockHostResolver), - cert_verifier(new CertVerifier), - proxy_service(proxy_service), - ssl_config_service(new SSLConfigServiceDefaults), - socket_factory(new MockClientSocketFactory), - deterministic_socket_factory(new DeterministicMockClientSocketFactory), - http_auth_handler_factory( - HttpAuthHandlerFactory::CreateDefault(host_resolver.get())) {} + explicit SpdySessionDependencies(ProxyService* proxy_service); + + ~SpdySessionDependencies(); + + static HttpNetworkSession* SpdyCreateSession( + SpdySessionDependencies* session_deps); + static HttpNetworkSession* SpdyCreateSessionDeterministic( + SpdySessionDependencies* session_deps); // NOTE: host_resolver must be ordered before http_auth_handler_factory. scoped_ptr<MockHostResolverBase> host_resolver; @@ -363,75 +346,16 @@ class SpdySessionDependencies { scoped_ptr<MockClientSocketFactory> socket_factory; scoped_ptr<DeterministicMockClientSocketFactory> deterministic_socket_factory; scoped_ptr<HttpAuthHandlerFactory> http_auth_handler_factory; - - static HttpNetworkSession* SpdyCreateSession( - SpdySessionDependencies* session_deps) { - return new HttpNetworkSession(session_deps->host_resolver.get(), - session_deps->cert_verifier.get(), - NULL /* dnsrr_resolver */, - NULL /* dns_cert_checker */, - NULL /* ssl_host_info_factory */, - session_deps->proxy_service, - session_deps->socket_factory.get(), - session_deps->ssl_config_service, - new SpdySessionPool(NULL), - session_deps->http_auth_handler_factory.get(), - NULL, - NULL); - } - static HttpNetworkSession* SpdyCreateSessionDeterministic( - SpdySessionDependencies* session_deps) { - return new HttpNetworkSession(session_deps->host_resolver.get(), - session_deps->cert_verifier.get(), - NULL /* dnsrr_resolver */, - NULL /* dns_cert_checker */, - NULL /* ssl_host_info_factory */, - session_deps->proxy_service, - session_deps-> - deterministic_socket_factory.get(), - session_deps->ssl_config_service, - new SpdySessionPool(NULL), - session_deps->http_auth_handler_factory.get(), - NULL, - NULL); - } }; class SpdyURLRequestContext : public URLRequestContext { public: - SpdyURLRequestContext() { - host_resolver_ = new MockHostResolver(); - cert_verifier_ = new CertVerifier; - proxy_service_ = ProxyService::CreateDirect(); - ssl_config_service_ = new SSLConfigServiceDefaults; - http_auth_handler_factory_ = HttpAuthHandlerFactory::CreateDefault( - host_resolver_); - http_transaction_factory_ = new HttpCache( - new HttpNetworkLayer(&socket_factory_, - host_resolver_, - cert_verifier_, - NULL /* dnsrr_resolver */, - NULL /* dns_cert_checker */, - NULL /* ssl_host_info_factory */, - proxy_service_, - ssl_config_service_, - new SpdySessionPool(NULL), - http_auth_handler_factory_, - network_delegate_, - NULL), - NULL /* net_log */, - HttpCache::DefaultBackend::InMemory(0)); - } + SpdyURLRequestContext(); MockClientSocketFactory& socket_factory() { return socket_factory_; } protected: - virtual ~SpdyURLRequestContext() { - delete http_transaction_factory_; - delete http_auth_handler_factory_; - delete cert_verifier_; - delete host_resolver_; - } + virtual ~SpdyURLRequestContext(); private: MockClientSocketFactory socket_factory_; diff --git a/net/test/test_server.cc b/net/test/test_server.cc index 9722dc1..14da7f4 100644 --- a/net/test/test_server.cc +++ b/net/test/test_server.cc @@ -28,6 +28,7 @@ #include "googleurl/src/gurl.h" #include "net/base/host_port_pair.h" #include "net/base/host_resolver.h" +#include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "net/base/test_root_certs.h" #include "net/socket/tcp_client_socket.h" diff --git a/net/test/test_server.h b/net/test/test_server.h index 6d93fc8..c83b284 100644 --- a/net/test/test_server.h +++ b/net/test/test_server.h @@ -35,6 +35,8 @@ class AddressList; // that can provide various responses useful for testing. class TestServer { public: + typedef std::pair<std::string, std::string> StringPair; + enum Type { TYPE_FTP, TYPE_HTTP, @@ -126,7 +128,6 @@ class TestServer { const std::string& user, const std::string& password) const; - typedef std::pair<std::string, std::string> StringPair; static bool GetFilePathWithReplacements( const std::string& original_path, const std::vector<StringPair>& text_to_replace, diff --git a/net/tools/dump_cache/upgrade.cc b/net/tools/dump_cache/upgrade.cc index 7b86237..6d79b73 100644 --- a/net/tools/dump_cache/upgrade.cc +++ b/net/tools/dump_cache/upgrade.cc @@ -12,6 +12,7 @@ #include "base/win/scoped_handle.h" #include "googleurl/src/gurl.h" #include "net/base/io_buffer.h" +#include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "net/disk_cache/backend_impl.h" #include "net/disk_cache/entry_impl.h" diff --git a/net/url_request/view_cache_helper_unittest.cc b/net/url_request/view_cache_helper_unittest.cc index b99a90c..5a656f2 100644 --- a/net/url_request/view_cache_helper_unittest.cc +++ b/net/url_request/view_cache_helper_unittest.cc @@ -5,6 +5,7 @@ #include "net/url_request/view_cache_helper.h" #include "base/pickle.h" +#include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "net/disk_cache/disk_cache.h" #include "net/http/http_cache.h" diff --git a/webkit/database/database_tracker_unittest.cc b/webkit/database/database_tracker_unittest.cc index 2afdf5e..641c207 100644 --- a/webkit/database/database_tracker_unittest.cc +++ b/webkit/database/database_tracker_unittest.cc @@ -8,6 +8,7 @@ #include "base/scoped_temp_dir.h" #include "base/time.h" #include "base/utf_string_conversions.h" +#include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "testing/gtest/include/gtest/gtest.h" #include "webkit/database/database_tracker.h" |