diff options
-rw-r--r-- | net/http/http_network_transaction_unittest.cc | 178 | ||||
-rw-r--r-- | net/http/http_proxy_client_socket_pool.cc | 36 | ||||
-rw-r--r-- | net/http/http_proxy_client_socket_pool.h | 12 | ||||
-rw-r--r-- | net/socket/client_socket_handle.cc | 21 | ||||
-rw-r--r-- | net/socket/client_socket_handle.h | 5 | ||||
-rw-r--r-- | net/socket/client_socket_pool.h | 21 | ||||
-rw-r--r-- | net/socket/client_socket_pool_base.cc | 81 | ||||
-rw-r--r-- | net/socket/client_socket_pool_base.h | 47 | ||||
-rw-r--r-- | net/socket/client_socket_pool_base_unittest.cc | 185 | ||||
-rw-r--r-- | net/socket/socks_client_socket_pool.cc | 27 | ||||
-rw-r--r-- | net/socket/socks_client_socket_pool.h | 12 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_pool.cc | 33 | ||||
-rw-r--r-- | net/socket/ssl_client_socket_pool.h | 10 | ||||
-rw-r--r-- | net/socket/transport_client_socket_pool.cc | 12 | ||||
-rw-r--r-- | net/socket/transport_client_socket_pool.h | 6 | ||||
-rw-r--r-- | net/spdy/spdy_session.cc | 18 | ||||
-rw-r--r-- | net/spdy/spdy_session.h | 7 |
17 files changed, 663 insertions, 48 deletions
diff --git a/net/http/http_network_transaction_unittest.cc b/net/http/http_network_transaction_unittest.cc index 86ea90c..c3a9d11 100644 --- a/net/http/http_network_transaction_unittest.cc +++ b/net/http/http_network_transaction_unittest.cc @@ -9642,4 +9642,182 @@ TEST_F(HttpNetworkTransactionTest, SendPipelineEvictionFallback) { EXPECT_EQ("hello world", out.response_data); } +TEST_F(HttpNetworkTransactionTest, CloseOldSpdySessionToOpenNewOne) { + HttpStreamFactory::set_next_protos(SpdyNextProtos()); + int old_max_sockets_per_group = + ClientSocketPoolManager::max_sockets_per_group(); + int old_max_sockets_per_proxy_server = + ClientSocketPoolManager::max_sockets_per_proxy_server(); + int old_max_sockets_per_pool = + ClientSocketPoolManager::max_sockets_per_pool(); + ClientSocketPoolManager::set_max_sockets_per_group(1); + ClientSocketPoolManager::set_max_sockets_per_proxy_server(1); + ClientSocketPoolManager::set_max_sockets_per_pool(1); + + // Use two different hosts with different IPs so they don't get pooled. + SessionDependencies session_deps; + session_deps.host_resolver->rules()->AddRule("a.com", "10.0.0.1"); + session_deps.host_resolver->rules()->AddRule("b.com", "10.0.0.2"); + scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps)); + + SSLSocketDataProvider ssl1(true, OK); + ssl1.next_proto_status = SSLClientSocket::kNextProtoNegotiated; + ssl1.next_proto = "spdy/2"; + ssl1.was_npn_negotiated = true; + SSLSocketDataProvider ssl2(true, OK); + ssl2.next_proto_status = SSLClientSocket::kNextProtoNegotiated; + ssl2.next_proto = "spdy/2"; + ssl2.was_npn_negotiated = true; + session_deps.socket_factory.AddSSLSocketDataProvider(&ssl1); + session_deps.socket_factory.AddSSLSocketDataProvider(&ssl2); + + scoped_ptr<spdy::SpdyFrame> host1_req(ConstructSpdyGet( + "https://www.a.com", false, 1, LOWEST)); + MockWrite spdy1_writes[] = { + CreateMockWrite(*host1_req, 1), + }; + scoped_ptr<spdy::SpdyFrame> host1_resp(ConstructSpdyGetSynReply(NULL, 0, 1)); + scoped_ptr<spdy::SpdyFrame> host1_resp_body(ConstructSpdyBodyFrame(1, true)); + MockRead spdy1_reads[] = { + CreateMockRead(*host1_resp, 2), + CreateMockRead(*host1_resp_body, 3), + MockRead(true, ERR_IO_PENDING, 4), + }; + + scoped_refptr<OrderedSocketData> spdy1_data( + new OrderedSocketData( + spdy1_reads, arraysize(spdy1_reads), + spdy1_writes, arraysize(spdy1_writes))); + session_deps.socket_factory.AddSocketDataProvider(spdy1_data); + + scoped_ptr<spdy::SpdyFrame> host2_req(ConstructSpdyGet( + "https://www.b.com", false, 1, LOWEST)); + MockWrite spdy2_writes[] = { + CreateMockWrite(*host2_req, 1), + }; + scoped_ptr<spdy::SpdyFrame> host2_resp(ConstructSpdyGetSynReply(NULL, 0, 1)); + scoped_ptr<spdy::SpdyFrame> host2_resp_body(ConstructSpdyBodyFrame(1, true)); + MockRead spdy2_reads[] = { + CreateMockRead(*host2_resp, 2), + CreateMockRead(*host2_resp_body, 3), + MockRead(true, ERR_IO_PENDING, 4), + }; + + scoped_refptr<OrderedSocketData> spdy2_data( + new OrderedSocketData( + spdy2_reads, arraysize(spdy2_reads), + spdy2_writes, arraysize(spdy2_writes))); + session_deps.socket_factory.AddSocketDataProvider(spdy2_data); + + MockWrite http_write[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.a.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead http_read[] = { + MockRead("HTTP/1.1 200 OK\r\n"), + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), + MockRead("Content-Length: 6\r\n\r\n"), + MockRead("hello!"), + }; + + StaticSocketDataProvider http_data(http_read, arraysize(http_read), + http_write, arraysize(http_write)); + session_deps.socket_factory.AddSocketDataProvider(&http_data); + + HostPortPair host_port_pair_a("www.a.com", 443); + HostPortProxyPair host_port_proxy_pair_a( + host_port_pair_a, ProxyServer::Direct()); + EXPECT_FALSE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_a)); + + TestOldCompletionCallback callback; + HttpRequestInfo request1; + request1.method = "GET"; + request1.url = GURL("https://www.a.com/"); + request1.load_flags = 0; + scoped_ptr<HttpNetworkTransaction> trans(new HttpNetworkTransaction(session)); + + int rv = trans->Start(&request1, &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback.WaitForResult()); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + ASSERT_TRUE(response != NULL); + ASSERT_TRUE(response->headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response->headers->GetStatusLine()); + EXPECT_TRUE(response->was_fetched_via_spdy); + EXPECT_TRUE(response->was_npn_negotiated); + + std::string response_data; + ASSERT_EQ(OK, ReadTransaction(trans.get(), &response_data)); + EXPECT_EQ("hello!", response_data); + trans.reset(); + EXPECT_TRUE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_a)); + + HostPortPair host_port_pair_b("www.b.com", 443); + HostPortProxyPair host_port_proxy_pair_b( + host_port_pair_b, ProxyServer::Direct()); + EXPECT_FALSE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_b)); + HttpRequestInfo request2; + request2.method = "GET"; + request2.url = GURL("https://www.b.com/"); + request2.load_flags = 0; + trans.reset(new HttpNetworkTransaction(session)); + + rv = trans->Start(&request2, &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback.WaitForResult()); + + response = trans->GetResponseInfo(); + ASSERT_TRUE(response != NULL); + ASSERT_TRUE(response->headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response->headers->GetStatusLine()); + EXPECT_TRUE(response->was_fetched_via_spdy); + EXPECT_TRUE(response->was_npn_negotiated); + ASSERT_EQ(OK, ReadTransaction(trans.get(), &response_data)); + EXPECT_EQ("hello!", response_data); + EXPECT_FALSE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_a)); + EXPECT_TRUE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_b)); + + HostPortPair host_port_pair_a1("www.a.com", 80); + HostPortProxyPair host_port_proxy_pair_a1( + host_port_pair_a1, ProxyServer::Direct()); + EXPECT_FALSE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_a1)); + HttpRequestInfo request3; + request3.method = "GET"; + request3.url = GURL("http://www.a.com/"); + request3.load_flags = 0; + trans.reset(new HttpNetworkTransaction(session)); + + rv = trans->Start(&request3, &callback, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback.WaitForResult()); + + response = trans->GetResponseInfo(); + ASSERT_TRUE(response != NULL); + ASSERT_TRUE(response->headers != NULL); + EXPECT_EQ("HTTP/1.1 200 OK", response->headers->GetStatusLine()); + EXPECT_FALSE(response->was_fetched_via_spdy); + EXPECT_FALSE(response->was_npn_negotiated); + ASSERT_EQ(OK, ReadTransaction(trans.get(), &response_data)); + EXPECT_EQ("hello!", response_data); + EXPECT_FALSE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_a)); + EXPECT_FALSE( + session->spdy_session_pool()->HasSession(host_port_proxy_pair_b)); + + HttpStreamFactory::set_next_protos(std::vector<std::string>()); + ClientSocketPoolManager::set_max_sockets_per_pool(old_max_sockets_per_pool); + ClientSocketPoolManager::set_max_sockets_per_proxy_server( + old_max_sockets_per_proxy_server); + ClientSocketPoolManager::set_max_sockets_per_group(old_max_sockets_per_group); +} + } // namespace net diff --git a/net/http/http_proxy_client_socket_pool.cc b/net/http/http_proxy_client_socket_pool.cc index d4c080f..f7d3206 100644 --- a/net/http/http_proxy_client_socket_pool.cc +++ b/net/http/http_proxy_client_socket_pool.cc @@ -391,9 +391,21 @@ HttpProxyClientSocketPool::HttpProxyClientSocketPool( new HttpProxyConnectJobFactory(transport_pool, ssl_pool, host_resolver, - net_log)) {} + net_log)) { + // We should always have a |transport_pool_| except in unit tests. + if (transport_pool_) + transport_pool_->AddLayeredPool(this); + if (ssl_pool_) + ssl_pool_->AddLayeredPool(this); +} -HttpProxyClientSocketPool::~HttpProxyClientSocketPool() {} +HttpProxyClientSocketPool::~HttpProxyClientSocketPool() { + if (ssl_pool_) + ssl_pool_->RemoveLayeredPool(this); + // We should always have a |transport_pool_| except in unit tests. + if (transport_pool_) + transport_pool_->RemoveLayeredPool(this); +} int HttpProxyClientSocketPool::RequestSocket(const std::string& group_name, const void* socket_params, @@ -434,6 +446,12 @@ void HttpProxyClientSocketPool::Flush() { base_.Flush(); } +bool HttpProxyClientSocketPool::IsStalled() const { + return base_.IsStalled() || + (transport_pool_ && transport_pool_->IsStalled()) || + (ssl_pool_ && ssl_pool_->IsStalled()); +} + void HttpProxyClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -452,6 +470,14 @@ LoadState HttpProxyClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } +void HttpProxyClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { + base_.AddLayeredPool(layered_pool); +} + +void HttpProxyClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { + base_.RemoveLayeredPool(layered_pool); +} + DictionaryValue* HttpProxyClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -482,4 +508,10 @@ ClientSocketPoolHistograms* HttpProxyClientSocketPool::histograms() const { return base_.histograms(); } +bool HttpProxyClientSocketPool::CloseOneIdleConnection() { + if (base_.CloseOneIdleSocket()) + return true; + return base_.CloseOneIdleConnectionInLayeredPool(); +} + } // namespace net diff --git a/net/http/http_proxy_client_socket_pool.h b/net/http/http_proxy_client_socket_pool.h index 6fe83dd..31d21cb 100644 --- a/net/http/http_proxy_client_socket_pool.h +++ b/net/http/http_proxy_client_socket_pool.h @@ -167,7 +167,8 @@ class HttpProxyConnectJob : public ConnectJob { DISALLOW_COPY_AND_ASSIGN(HttpProxyConnectJob); }; -class NET_EXPORT_PRIVATE HttpProxyClientSocketPool : public ClientSocketPool { +class NET_EXPORT_PRIVATE HttpProxyClientSocketPool + : public ClientSocketPool, public LayeredPool { public: HttpProxyClientSocketPool( int max_sockets, @@ -202,6 +203,8 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool : public ClientSocketPool { virtual void Flush() OVERRIDE; + virtual bool IsStalled() const OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -213,6 +216,10 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool : public ClientSocketPool { const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; + virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; + virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -222,6 +229,9 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool : public ClientSocketPool { virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + // LayeredPool methods: + virtual bool CloseOneIdleConnection() OVERRIDE; + private: typedef ClientSocketPoolBase<HttpProxySocketParams> PoolBase; diff --git a/net/socket/client_socket_handle.cc b/net/socket/client_socket_handle.cc index 8309c3f..e093ec9 100644 --- a/net/socket/client_socket_handle.cc +++ b/net/socket/client_socket_handle.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Copyright (c) 2011 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -15,6 +15,8 @@ namespace net { ClientSocketHandle::ClientSocketHandle() : is_initialized_(false), + pool_(NULL), + layered_pool_(NULL), is_reused_(false), ALLOW_THIS_IN_INITIALIZER_LIST( callback_(this, &ClientSocketHandle::OnIOComplete)), @@ -49,6 +51,10 @@ void ClientSocketHandle::ResetInternal(bool cancel) { group_name_.clear(); is_reused_ = false; user_callback_ = NULL; + if (layered_pool_) { + pool_->RemoveLayeredPool(layered_pool_); + layered_pool_ = NULL; + } pool_ = NULL; idle_time_ = base::TimeDelta(); init_time_ = base::TimeTicks(); @@ -72,6 +78,19 @@ LoadState ClientSocketHandle::GetLoadState() const { return pool_->GetLoadState(group_name_, this); } +bool ClientSocketHandle::IsPoolStalled() const { + return pool_->IsStalled(); +} + +void ClientSocketHandle::AddLayeredPool(LayeredPool* layered_pool) { + CHECK(layered_pool); + CHECK(!layered_pool_); + if (pool_) { + pool_->AddLayeredPool(layered_pool); + layered_pool_ = layered_pool; + } +} + void ClientSocketHandle::OnIOComplete(int result) { OldCompletionCallback* callback = user_callback_; user_callback_ = NULL; diff --git a/net/socket/client_socket_handle.h b/net/socket/client_socket_handle.h index cab6ef6..23fc5cf 100644 --- a/net/socket/client_socket_handle.h +++ b/net/socket/client_socket_handle.h @@ -92,6 +92,10 @@ class NET_EXPORT ClientSocketHandle { // initialized the ClientSocketHandle. LoadState GetLoadState() const; + bool IsPoolStalled() const; + + void AddLayeredPool(LayeredPool* layered_pool); + // Returns true when Init() has completed successfully. bool is_initialized() const { return is_initialized_; } @@ -164,6 +168,7 @@ class NET_EXPORT ClientSocketHandle { bool is_initialized_; ClientSocketPool* pool_; + LayeredPool* layered_pool_; scoped_ptr<StreamSocket> socket_; std::string group_name_; bool is_reused_; diff --git a/net/socket/client_socket_pool.h b/net/socket/client_socket_pool.h index 5f618cd..315d436 100644 --- a/net/socket/client_socket_pool.h +++ b/net/socket/client_socket_pool.h @@ -29,6 +29,17 @@ class ClientSocketHandle; class ClientSocketPoolHistograms; class StreamSocket; +// ClientSocketPools are layered. This defines an interface for lower level +// socket pools to communicate with higher layer pools. +class LayeredPool { + public: + virtual ~LayeredPool() {} + + // Instructs the LayeredPool to close an idle connection. Return true if one + // was closed. + virtual bool CloseOneIdleConnection() = 0; +}; + // A ClientSocketPool is used to restrict the number of sockets open at a time. // It also maintains a list of idle persistent sockets. // @@ -110,6 +121,10 @@ class NET_EXPORT ClientSocketPool { // the pool. Does not flush any pools wrapped by |this|. virtual void Flush() = 0; + // Returns true if a new request may hit a per-pool (not per-host) max socket + // limit. + virtual bool IsStalled() const = 0; + // Called to close any idle connections held by the connection manager. virtual void CloseIdleSockets() = 0; @@ -123,6 +138,12 @@ class NET_EXPORT ClientSocketPool { virtual LoadState GetLoadState(const std::string& group_name, const ClientSocketHandle* handle) const = 0; + // Adds a LayeredPool on top of |this|. + virtual void AddLayeredPool(LayeredPool* layered_pool) = 0; + + // Removes a LayeredPool from |this|. + virtual void RemoveLayeredPool(LayeredPool* layered_pool) = 0; + // Retrieves information on the current state of the pool as a // DictionaryValue. Caller takes possession of the returned value. // If |include_nested_pools| is true, the states of any nested diff --git a/net/socket/client_socket_pool_base.cc b/net/socket/client_socket_pool_base.cc index 4c1600e..bb7eb06 100644 --- a/net/socket/client_socket_pool_base.cc +++ b/net/socket/client_socket_pool_base.cc @@ -207,6 +207,7 @@ ClientSocketPoolBaseHelper::~ClientSocketPoolBaseHelper() { DCHECK(group_map_.empty()); DCHECK(pending_callback_map_.empty()); DCHECK_EQ(0, connecting_socket_count_); + DCHECK(higher_layer_pools_.empty()); NetworkChangeNotifier::RemoveIPAddressObserver(this); } @@ -236,6 +237,18 @@ ClientSocketPoolBaseHelper::RemoveRequestFromQueue( return req; } +void ClientSocketPoolBaseHelper::AddLayeredPool(LayeredPool* pool) { + CHECK(pool); + CHECK(!ContainsKey(higher_layer_pools_, pool)); + higher_layer_pools_.insert(pool); +} + +void ClientSocketPoolBaseHelper::RemoveLayeredPool(LayeredPool* pool) { + CHECK(pool); + CHECK(ContainsKey(higher_layer_pools_, pool)); + higher_layer_pools_.erase(pool); +} + int ClientSocketPoolBaseHelper::RequestSocket( const std::string& group_name, const Request* request) { @@ -334,6 +347,10 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( // Can we make another active socket now? if (!group->HasAvailableSocketSlot(max_sockets_per_group_) && !request->ignore_limits()) { + // TODO(willchan): Consider whether or not we need to close a socket in a + // higher layered group. I don't think this makes sense since we would just + // reuse that socket then if we needed one and wouldn't make it down to this + // layer. request->net_log().AddEvent( NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP, NULL); return ERR_IO_PENDING; @@ -341,19 +358,28 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( if (ReachedMaxSocketsLimit() && !request->ignore_limits()) { if (idle_socket_count() > 0) { + // There's an idle socket in this pool. Either that's because there's + // still one in this group, but we got here due to preconnecting bypassing + // idle sockets, or because there's an idle socket in another group. bool closed = CloseOneIdleSocketExceptInGroup(group); if (preconnecting && !closed) return ERR_PRECONNECT_MAX_SOCKET_LIMIT; } else { - // We could check if we really have a stalled group here, but it requires - // a scan of all groups, so just flip a flag here, and do the check later. - request->net_log().AddEvent( - NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS, NULL); - return ERR_IO_PENDING; + do { + if (!CloseOneIdleConnectionInLayeredPool()) { + // We could check if we really have a stalled group here, but it + // requires a scan of all groups, so just flip a flag here, and do + // the check later. + request->net_log().AddEvent( + NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS, NULL); + return ERR_IO_PENDING; + } + } while (ReachedMaxSocketsLimit()); } } - // We couldn't find a socket to reuse, so allocate and connect a new one. + // We couldn't find a socket to reuse, and there's space to allocate one, + // so allocate and connect a new one. scoped_ptr<ConnectJob> connect_job( connect_job_factory_->NewConnectJob(group_name, *request, this)); @@ -790,18 +816,22 @@ void ClientSocketPoolBaseHelper::CheckForStalledSocketGroups() { // are not at the |max_sockets_per_group_| limit. Note: for requests with // the same priority, the winner is based on group hash ordering (and not // insertion order). -bool ClientSocketPoolBaseHelper::FindTopStalledGroup(Group** group, - std::string* group_name) { +bool ClientSocketPoolBaseHelper::FindTopStalledGroup( + Group** group, + std::string* group_name) const { + CHECK((group && group_name) || (!group && !group_name)); Group* top_group = NULL; const std::string* top_group_name = NULL; bool has_stalled_group = false; - for (GroupMap::iterator i = group_map_.begin(); + for (GroupMap::const_iterator i = group_map_.begin(); i != group_map_.end(); ++i) { Group* curr_group = i->second; const RequestQueue& queue = curr_group->pending_requests(); if (queue.empty()) continue; if (curr_group->IsStalled(max_sockets_per_group_)) { + if (!group) + return true; has_stalled_group = true; bool has_higher_priority = !top_group || curr_group->TopPendingPriority() < top_group->TopPendingPriority(); @@ -813,8 +843,11 @@ bool ClientSocketPoolBaseHelper::FindTopStalledGroup(Group** group, } if (top_group) { + CHECK(group); *group = top_group; *group_name = *top_group_name; + } else { + CHECK(!has_stalled_group); } return has_stalled_group; } @@ -887,6 +920,17 @@ void ClientSocketPoolBaseHelper::Flush() { AbortAllRequests(); } +bool ClientSocketPoolBaseHelper::IsStalled() const { + if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_) + return false; + for (GroupMap::const_iterator it = group_map_.begin(); + it != group_map_.end(); it++) { + if (it->second->IsStalled(max_sockets_per_group_)) + return true; + } + return false; +} + void ClientSocketPoolBaseHelper::RemoveConnectJob(ConnectJob* job, Group* group) { CHECK_GT(connecting_socket_count_, 0); @@ -1023,8 +1067,10 @@ bool ClientSocketPoolBaseHelper::ReachedMaxSocketsLimit() const { return true; } -void ClientSocketPoolBaseHelper::CloseOneIdleSocket() { - CloseOneIdleSocketExceptInGroup(NULL); +bool ClientSocketPoolBaseHelper::CloseOneIdleSocket() { + if (idle_socket_count() == 0) + return false; + return CloseOneIdleSocketExceptInGroup(NULL); } bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup( @@ -1048,9 +1094,18 @@ bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup( } } - if (!exception_group) - LOG(DFATAL) << "No idle socket found to close!."; + return false; +} +bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInLayeredPool() { + // This pool doesn't have any idle sockets. It's possible that a pool at a + // higher layer is holding one of this sockets active, but it's actually idle. + // Query the higher layers. + for (std::set<LayeredPool*>::const_iterator it = higher_layer_pools_.begin(); + it != higher_layer_pools_.end(); ++it) { + if ((*it)->CloseOneIdleConnection()) + return true; + } return false; } diff --git a/net/socket/client_socket_pool_base.h b/net/socket/client_socket_pool_base.h index 5000f36..052c613 100644 --- a/net/socket/client_socket_pool_base.h +++ b/net/socket/client_socket_pool_base.h @@ -28,6 +28,7 @@ #include <map> #include <set> #include <string> +#include <vector> #include "base/basictypes.h" #include "base/memory/ref_counted.h" @@ -239,6 +240,11 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper virtual ~ClientSocketPoolBaseHelper(); + // Adds/Removes layered pools. It is expected in the destructor that no + // layered pools remain. + void AddLayeredPool(LayeredPool* pool); + void RemoveLayeredPool(LayeredPool* pool); + // See ClientSocketPool::RequestSocket for documentation on this function. // ClientSocketPoolBaseHelper takes ownership of |request|, which must be // heap allocated. @@ -261,6 +267,9 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // See ClientSocketPool::Flush for documentation on this function. void Flush(); + // See ClientSocketPool::IsStalled for documentation on this function. + bool IsStalled() const; + // See ClientSocketPool::CloseIdleSockets for documentation on this function. void CloseIdleSockets(); @@ -305,6 +314,16 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // sockets that timed out or can't be reused. Made public for testing. void CleanupIdleSockets(bool force); + // Closes one idle socket. Picks the first one encountered. + // TODO(willchan): Consider a better algorithm for doing this. Perhaps we + // should keep an ordered list of idle sockets, and close them in order. + // Requires maintaining more state. It's not clear if it's worth it since + // I'm not sure if we hit this situation often. + bool CloseOneIdleSocket(); + + // Checks layered pools to see if they can close an idle connection. + bool CloseOneIdleConnectionInLayeredPool(); + // See ClientSocketPool::GetInfoAsValue for documentation on this function. base::DictionaryValue* GetInfoAsValue(const std::string& name, const std::string& type) const; @@ -457,7 +476,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // at least one pending request. Returns true if any groups are stalled, and // if so, fills |group| and |group_name| with data of the stalled group // having highest priority. - bool FindTopStalledGroup(Group** group, std::string* group_name); + bool FindTopStalledGroup(Group** group, std::string* group_name) const; // Called when timer_ fires. This method scans the idle sockets removing // sockets that timed out or can't be reused. @@ -509,13 +528,6 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper static void LogBoundConnectJobToRequest( const NetLog::Source& connect_job_source, const Request* request); - // Closes one idle socket. Picks the first one encountered. - // TODO(willchan): Consider a better algorithm for doing this. Perhaps we - // should keep an ordered list of idle sockets, and close them in order. - // Requires maintaining more state. It's not clear if it's worth it since - // I'm not sure if we hit this situation often. - void CloseOneIdleSocket(); - // Same as CloseOneIdleSocket() except it won't close an idle socket in // |group|. If |group| is NULL, it is ignored. Returns true if it closed a // socket. @@ -580,6 +592,8 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // make sure that they are discarded rather than reused. int pool_generation_number_; + std::set<LayeredPool*> higher_layer_pools_; + ScopedRunnableMethodFactory<ClientSocketPoolBaseHelper> method_factory_; DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolBaseHelper); @@ -646,6 +660,13 @@ class ClientSocketPoolBase { virtual ~ClientSocketPoolBase() {} // These member functions simply forward to ClientSocketPoolBaseHelper. + void AddLayeredPool(LayeredPool* pool) { + helper_.AddLayeredPool(pool); + } + + void RemoveLayeredPool(LayeredPool* pool) { + helper_.RemoveLayeredPool(pool); + } // RequestSocket bundles up the parameters into a Request and then forwards to // ClientSocketPoolBaseHelper::RequestSocket(). @@ -690,6 +711,10 @@ class ClientSocketPoolBase { return helper_.ReleaseSocket(group_name, socket, id); } + void Flush() { helper_.Flush(); } + + bool IsStalled() const { return helper_.IsStalled(); } + void CloseIdleSockets() { return helper_.CloseIdleSockets(); } int idle_socket_count() const { return helper_.idle_socket_count(); } @@ -738,7 +763,11 @@ class ClientSocketPoolBase { void EnableConnectBackupJobs() { helper_.EnableConnectBackupJobs(); } - void Flush() { helper_.Flush(); } + bool CloseOneIdleSocket() { return helper_.CloseOneIdleSocket(); } + + bool CloseOneIdleConnectionInLayeredPool() { + return helper_.CloseOneIdleConnectionInLayeredPool(); + } private: // This adaptor class exists to bridge the diff --git a/net/socket/client_socket_pool_base_unittest.cc b/net/socket/client_socket_pool_base_unittest.cc index d318dd6..3f7e954 100644 --- a/net/socket/client_socket_pool_base_unittest.cc +++ b/net/socket/client_socket_pool_base_unittest.cc @@ -25,8 +25,12 @@ #include "net/socket/socket_test_util.h" #include "net/socket/ssl_host_info.h" #include "net/socket/stream_socket.h" +#include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" +using ::testing::Invoke; +using ::testing::Return; + namespace net { namespace { @@ -37,10 +41,18 @@ const net::RequestPriority kDefaultPriority = MEDIUM; class TestSocketParams : public base::RefCounted<TestSocketParams> { public: - bool ignore_limits() { return false; } + TestSocketParams() : ignore_limits_(false) {} + + void set_ignore_limits(bool ignore_limits) { + ignore_limits_ = ignore_limits; + } + bool ignore_limits() { return ignore_limits_; } + private: friend class base::RefCounted<TestSocketParams>; ~TestSocketParams() {} + + bool ignore_limits_; }; typedef ClientSocketPoolBase<TestSocketParams> TestClientSocketPoolBase; @@ -414,7 +426,7 @@ class TestClientSocketPool : public ClientSocketPool { net::RequestPriority priority, ClientSocketHandle* handle, OldCompletionCallback* callback, - const BoundNetLog& net_log) { + const BoundNetLog& net_log) OVERRIDE { const scoped_refptr<TestSocketParams>* casted_socket_params = static_cast<const scoped_refptr<TestSocketParams>*>(params); return base_.RequestSocket(group_name, *casted_socket_params, priority, @@ -424,7 +436,7 @@ class TestClientSocketPool : public ClientSocketPool { virtual void RequestSockets(const std::string& group_name, const void* params, int num_sockets, - const BoundNetLog& net_log) { + const BoundNetLog& net_log) OVERRIDE { const scoped_refptr<TestSocketParams>* casted_params = static_cast<const scoped_refptr<TestSocketParams>*>(params); @@ -433,47 +445,64 @@ class TestClientSocketPool : public ClientSocketPool { virtual void CancelRequest( const std::string& group_name, - ClientSocketHandle* handle) { + ClientSocketHandle* handle) OVERRIDE { base_.CancelRequest(group_name, handle); } virtual void ReleaseSocket( const std::string& group_name, StreamSocket* socket, - int id) { + int id) OVERRIDE { base_.ReleaseSocket(group_name, socket, id); } - virtual void Flush() { + virtual void Flush() OVERRIDE { base_.Flush(); } - virtual void CloseIdleSockets() { + virtual bool IsStalled() const OVERRIDE { + return base_.IsStalled(); + } + + virtual void CloseIdleSockets() OVERRIDE { base_.CloseIdleSockets(); } - virtual int IdleSocketCount() const { return base_.idle_socket_count(); } + virtual int IdleSocketCount() const OVERRIDE { + return base_.idle_socket_count(); + } - virtual int IdleSocketCountInGroup(const std::string& group_name) const { + virtual int IdleSocketCountInGroup( + const std::string& group_name) const OVERRIDE { return base_.IdleSocketCountInGroup(group_name); } - virtual LoadState GetLoadState(const std::string& group_name, - const ClientSocketHandle* handle) const { + virtual LoadState GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const OVERRIDE { return base_.GetLoadState(group_name, handle); } - virtual DictionaryValue* GetInfoAsValue(const std::string& name, - const std::string& type, - bool include_nested_pools) const { + virtual void AddLayeredPool(LayeredPool* pool) OVERRIDE { + base_.AddLayeredPool(pool); + } + + virtual void RemoveLayeredPool(LayeredPool* pool) OVERRIDE { + base_.RemoveLayeredPool(pool); + } + + virtual DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const OVERRIDE { return base_.GetInfoAsValue(name, type); } - virtual base::TimeDelta ConnectionTimeout() const { + virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { return base_.ConnectionTimeout(); } - virtual ClientSocketPoolHistograms* histograms() const { + virtual ClientSocketPoolHistograms* histograms() const OVERRIDE { return base_.histograms(); } @@ -495,6 +524,10 @@ class TestClientSocketPool : public ClientSocketPool { void EnableConnectBackupJobs() { base_.EnableConnectBackupJobs(); } + bool CloseOneIdleConnectionInLayeredPool() { + return base_.CloseOneIdleConnectionInLayeredPool(); + } + private: TestClientSocketPoolBase base_; @@ -1160,6 +1193,7 @@ TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { ClientSocketHandle stalled_handle; TestOldCompletionCallback callback; { + EXPECT_FALSE(pool_->IsStalled()); ClientSocketHandle handles[kDefaultMaxSockets]; for (int i = 0; i < kDefaultMaxSockets; ++i) { TestOldCompletionCallback callback; @@ -1174,6 +1208,7 @@ TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); EXPECT_EQ(0, pool_->IdleSocketCount()); + EXPECT_FALSE(pool_->IsStalled()); // Now we will hit the socket limit. EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", @@ -1182,6 +1217,7 @@ TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { &callback, pool_.get(), BoundNetLog())); + EXPECT_TRUE(pool_->IsStalled()); // Dropping out of scope will close all handles and return them to idle. } @@ -2000,9 +2036,9 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimer) { EXPECT_EQ(1, handle2.socket()->Write(NULL, 1, NULL)); handle2.Reset(); - // The idle socket timeout value was set to 10 milliseconds. Wait 20 + // The idle socket timeout value was set to 10 milliseconds. Wait 100 // milliseconds so the sockets timeout. - base::PlatformThread::Sleep(20); + base::PlatformThread::Sleep(100); MessageLoop::current()->RunAllPending(); ASSERT_EQ(2, pool_->IdleSocketCount()); @@ -3029,6 +3065,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsHitMaxSocketLimit) { ASSERT_TRUE(pool_->HasGroup("a")); EXPECT_EQ(kDefaultMaxSockets - 1, pool_->NumConnectJobsInGroup("a")); + EXPECT_FALSE(pool_->IsStalled()); ASSERT_FALSE(pool_->HasGroup("b")); @@ -3037,6 +3074,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsHitMaxSocketLimit) { ASSERT_TRUE(pool_->HasGroup("b")); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("b")); + EXPECT_FALSE(pool_->IsStalled()); } TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountIdleSockets) { @@ -3355,6 +3393,117 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) { EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); } +class MockLayeredPool : public LayeredPool { + public: + MockLayeredPool(TestClientSocketPool* pool, + const std::string& group_name) + : pool_(pool), + params_(new TestSocketParams), + group_name_(group_name) { + pool_->AddLayeredPool(this); + } + + ~MockLayeredPool() { + pool_->RemoveLayeredPool(this); + } + + int RequestSocket(TestClientSocketPool* pool) { + return handle_.Init(group_name_, params_, kDefaultPriority, &callback_, + pool, BoundNetLog()); + } + + int RequestSocketWithoutLimits(TestClientSocketPool* pool) { + params_->set_ignore_limits(true); + return handle_.Init(group_name_, params_, kDefaultPriority, &callback_, + pool, BoundNetLog()); + } + + bool ReleaseOneConnection() { + if (!handle_.is_initialized()) { + return false; + } + handle_.socket()->Disconnect(); + handle_.Reset(); + return true; + } + + MOCK_METHOD0(CloseOneIdleConnection, bool()); + + private: + TestClientSocketPool* const pool_; + scoped_refptr<TestSocketParams> params_; + ClientSocketHandle handle_; + TestOldCompletionCallback callback_; + const std::string group_name_; +}; + +TEST_F(ClientSocketPoolBaseTest, FailToCloseIdleSocketsNotHeldByLayeredPool) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool(pool_.get(), "foo"); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillOnce(Return(false)); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_FALSE(pool_->CloseOneIdleConnectionInLayeredPool()); +} + +TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) { + CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool(pool_.get(), "foo"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillOnce(Invoke(&mock_layered_pool, + &MockLayeredPool::ReleaseOneConnection)); + EXPECT_TRUE(pool_->CloseOneIdleConnectionInLayeredPool()); +} + +TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketsHeldByLayeredPoolWhenNeeded) { + CreatePool(1, 1); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool(pool_.get(), "foo"); + EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) + .WillOnce(Invoke(&mock_layered_pool, + &MockLayeredPool::ReleaseOneConnection)); + ClientSocketHandle handle; + TestOldCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", + params_, + kDefaultPriority, + &callback, + pool_.get(), + BoundNetLog())); +} + +TEST_F(ClientSocketPoolBaseTest, + CloseMultipleIdleSocketsHeldByLayeredPoolWhenNeeded) { + CreatePool(1, 1); + connect_job_factory_->set_job_type(TestConnectJob::kMockJob); + + MockLayeredPool mock_layered_pool1(pool_.get(), "foo"); + EXPECT_EQ(OK, mock_layered_pool1.RequestSocket(pool_.get())); + EXPECT_CALL(mock_layered_pool1, CloseOneIdleConnection()) + .WillRepeatedly(Invoke(&mock_layered_pool1, + &MockLayeredPool::ReleaseOneConnection)); + MockLayeredPool mock_layered_pool2(pool_.get(), "bar"); + EXPECT_EQ(OK, mock_layered_pool2.RequestSocketWithoutLimits(pool_.get())); + EXPECT_CALL(mock_layered_pool2, CloseOneIdleConnection()) + .WillRepeatedly(Invoke(&mock_layered_pool2, + &MockLayeredPool::ReleaseOneConnection)); + ClientSocketHandle handle; + TestOldCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", + params_, + kDefaultPriority, + &callback, + pool_.get(), + BoundNetLog())); +} + } // namespace } // namespace net diff --git a/net/socket/socks_client_socket_pool.cc b/net/socket/socks_client_socket_pool.cc index 7f1967b..5fa52bf 100644 --- a/net/socket/socks_client_socket_pool.cc +++ b/net/socket/socks_client_socket_pool.cc @@ -198,9 +198,16 @@ SOCKSClientSocketPool::SOCKSClientSocketPool( new SOCKSConnectJobFactory(transport_pool, host_resolver, net_log)) { + // We should always have a |transport_pool_| except in unit tests. + if (transport_pool_) + transport_pool_->AddLayeredPool(this); } -SOCKSClientSocketPool::~SOCKSClientSocketPool() {} +SOCKSClientSocketPool::~SOCKSClientSocketPool() { + // We should always have a |transport_pool_| except in unit tests. + if (transport_pool_) + transport_pool_->RemoveLayeredPool(this); +} int SOCKSClientSocketPool::RequestSocket(const std::string& group_name, const void* socket_params, @@ -240,6 +247,10 @@ void SOCKSClientSocketPool::Flush() { base_.Flush(); } +bool SOCKSClientSocketPool::IsStalled() const { + return base_.IsStalled() || transport_pool_->IsStalled(); +} + void SOCKSClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -258,6 +269,14 @@ LoadState SOCKSClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } +void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { + base_.AddLayeredPool(layered_pool); +} + +void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { + base_.RemoveLayeredPool(layered_pool); +} + DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -281,4 +300,10 @@ ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { return base_.histograms(); }; +bool SOCKSClientSocketPool::CloseOneIdleConnection() { + if (base_.CloseOneIdleSocket()) + return true; + return base_.CloseOneIdleConnectionInLayeredPool(); +} + } // namespace net diff --git a/net/socket/socks_client_socket_pool.h b/net/socket/socks_client_socket_pool.h index 34352b48..501f3bf 100644 --- a/net/socket/socks_client_socket_pool.h +++ b/net/socket/socks_client_socket_pool.h @@ -105,7 +105,8 @@ class SOCKSConnectJob : public ConnectJob { DISALLOW_COPY_AND_ASSIGN(SOCKSConnectJob); }; -class NET_EXPORT_PRIVATE SOCKSClientSocketPool : public ClientSocketPool { +class NET_EXPORT_PRIVATE SOCKSClientSocketPool + : public ClientSocketPool, public LayeredPool { public: SOCKSClientSocketPool( int max_sockets, @@ -139,6 +140,8 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool : public ClientSocketPool { virtual void Flush() OVERRIDE; + virtual bool IsStalled() const OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -150,6 +153,10 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool : public ClientSocketPool { const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; + virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; + virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -159,6 +166,9 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool : public ClientSocketPool { virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + // LayeredPool methods: + virtual bool CloseOneIdleConnection() OVERRIDE; + private: typedef ClientSocketPoolBase<SOCKSSocketParams> PoolBase; diff --git a/net/socket/ssl_client_socket_pool.cc b/net/socket/ssl_client_socket_pool.cc index 4fffac6..601ad73 100644 --- a/net/socket/ssl_client_socket_pool.cc +++ b/net/socket/ssl_client_socket_pool.cc @@ -477,9 +477,21 @@ SSLClientSocketPool::SSLClientSocketPool( ssl_config_service_(ssl_config_service) { if (ssl_config_service_) ssl_config_service_->AddObserver(this); + if (transport_pool_) + transport_pool_->AddLayeredPool(this); + if (socks_pool_) + socks_pool_->AddLayeredPool(this); + if (http_proxy_pool_) + http_proxy_pool_->AddLayeredPool(this); } SSLClientSocketPool::~SSLClientSocketPool() { + if (http_proxy_pool_) + http_proxy_pool_->RemoveLayeredPool(this); + if (socks_pool_) + socks_pool_->RemoveLayeredPool(this); + if (transport_pool_) + transport_pool_->RemoveLayeredPool(this); if (ssl_config_service_) ssl_config_service_->RemoveObserver(this); } @@ -532,6 +544,13 @@ void SSLClientSocketPool::Flush() { base_.Flush(); } +bool SSLClientSocketPool::IsStalled() const { + return base_.IsStalled() || + (transport_pool_ && transport_pool_->IsStalled()) || + (socks_pool_ && socks_pool_->IsStalled()) || + (http_proxy_pool_ && http_proxy_pool_->IsStalled()); +} + void SSLClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -550,6 +569,14 @@ LoadState SSLClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } +void SSLClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { + base_.AddLayeredPool(layered_pool); +} + +void SSLClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { + base_.RemoveLayeredPool(layered_pool); +} + DictionaryValue* SSLClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -589,4 +616,10 @@ void SSLClientSocketPool::OnSSLConfigChanged() { Flush(); } +bool SSLClientSocketPool::CloseOneIdleConnection() { + if (base_.CloseOneIdleSocket()) + return true; + return base_.CloseOneIdleConnectionInLayeredPool(); +} + } // namespace net diff --git a/net/socket/ssl_client_socket_pool.h b/net/socket/ssl_client_socket_pool.h index 3305a6f..ac5c22b7 100644 --- a/net/socket/ssl_client_socket_pool.h +++ b/net/socket/ssl_client_socket_pool.h @@ -167,6 +167,7 @@ class SSLConnectJob : public ConnectJob { class NET_EXPORT_PRIVATE SSLClientSocketPool : public ClientSocketPool, + public LayeredPool, public SSLConfigService::Observer { public: // Only the pools that will be used are required. i.e. if you never @@ -212,6 +213,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool virtual void Flush() OVERRIDE; + virtual bool IsStalled() const OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -223,6 +226,10 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; + virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; + virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -232,6 +239,9 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + // LayeredPool methods: + virtual bool CloseOneIdleConnection() OVERRIDE; + private: typedef ClientSocketPoolBase<SSLSocketParams> PoolBase; diff --git a/net/socket/transport_client_socket_pool.cc b/net/socket/transport_client_socket_pool.cc index 38ac74e..a423ebd 100644 --- a/net/socket/transport_client_socket_pool.cc +++ b/net/socket/transport_client_socket_pool.cc @@ -450,6 +450,10 @@ void TransportClientSocketPool::Flush() { base_.Flush(); } +bool TransportClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + void TransportClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -468,6 +472,14 @@ LoadState TransportClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } +void TransportClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { + base_.AddLayeredPool(layered_pool); +} + +void TransportClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { + base_.RemoveLayeredPool(layered_pool); +} + DictionaryValue* TransportClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, diff --git a/net/socket/transport_client_socket_pool.h b/net/socket/transport_client_socket_pool.h index 055fae1..e3e6f02 100644 --- a/net/socket/transport_client_socket_pool.h +++ b/net/socket/transport_client_socket_pool.h @@ -165,6 +165,8 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual void Flush() OVERRIDE; + virtual bool IsStalled() const OVERRIDE; + virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -176,6 +178,10 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; + virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; + + virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; + virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, diff --git a/net/spdy/spdy_session.cc b/net/spdy/spdy_session.cc index 6455167..686f4ae 100644 --- a/net/spdy/spdy_session.cc +++ b/net/spdy/spdy_session.cc @@ -336,6 +336,7 @@ net::Error SpdySession::InitializeWithSocket( state_ = CONNECTED; connection_.reset(connection); + connection_->AddLayeredPool(this); is_secure_ = is_secure; certificate_error_code_ = certificate_error_code; @@ -984,6 +985,15 @@ int SpdySession::GetLocalAddress(IPEndPoint* address) const { return connection_->socket()->GetLocalAddress(address); } +bool SpdySession::CloseOneIdleConnection() { + if (num_active_streams() == 0) { + // Should delete this. + RemoveFromPool(); + return true; + } + return false; +} + void SpdySession::ActivateStream(SpdyStream* stream) { const spdy::SpdyStreamId id = stream->stream_id(); DCHECK(!IsStreamActive(id)); @@ -1018,12 +1028,18 @@ void SpdySession::DeleteStream(spdy::SpdyStreamId id, int status) { if (stream) stream->OnClose(status); ProcessPendingCreateStreams(); + if (num_active_streams() == 0 && connection_->is_initialized() && + connection_->IsPoolStalled()) { + // Should delete this. + RemoveFromPool(); + } } void SpdySession::RemoveFromPool() { if (spdy_session_pool_) { - spdy_session_pool_->Remove(make_scoped_refptr(this)); + SpdySessionPool* pool = spdy_session_pool_; spdy_session_pool_ = NULL; + pool->Remove(make_scoped_refptr(this)); } } diff --git a/net/spdy/spdy_session.h b/net/spdy/spdy_session.h index f46e03f..ce02d5e 100644 --- a/net/spdy/spdy_session.h +++ b/net/spdy/spdy_session.h @@ -15,6 +15,7 @@ #include "base/gtest_prod_util.h" #include "base/memory/linked_ptr.h" #include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" #include "base/task.h" #include "net/base/io_buffer.h" #include "net/base/load_states.h" @@ -45,7 +46,8 @@ class SpdyStream; class SSLInfo; class NET_EXPORT SpdySession : public base::RefCounted<SpdySession>, - public spdy::SpdyFramerVisitorInterface { + public spdy::SpdyFramerVisitorInterface, + public LayeredPool { public: // Create a new SpdySession. // |host_port_proxy_pair| is the host/port that this session connects to, and @@ -228,6 +230,9 @@ class NET_EXPORT SpdySession : public base::RefCounted<SpdySession>, int GetPeerAddress(AddressList* address) const; int GetLocalAddress(IPEndPoint* address) const; + // LayeredPool methods: + virtual bool CloseOneIdleConnection() OVERRIDE; + private: friend class base::RefCounted<SpdySession>; // Allow tests to access our innards for testing purposes. |