diff options
Diffstat (limited to 'net/socket_stream')
-rw-r--r-- | net/socket_stream/socket_stream.cc | 58 | ||||
-rw-r--r-- | net/socket_stream/socket_stream.h | 7 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_unittest.cc | 111 |
3 files changed, 159 insertions, 17 deletions
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index b49a6a1..4c5b13a 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -234,6 +234,10 @@ void SocketStream::DetachDelegate() { Close(); } +const ProxyServer& SocketStream::proxy_server() const { + return proxy_info_.proxy_server(); +} + void SocketStream::SetHostResolver(HostResolver* host_resolver) { DCHECK(host_resolver); host_resolver_ = host_resolver; @@ -401,6 +405,12 @@ void SocketStream::DoLoop(int result) { case STATE_RESOLVE_HOST_COMPLETE: result = DoResolveHostComplete(result); break; + case STATE_RESOLVE_PROTOCOL: + result = DoResolveProtocol(result); + break; + case STATE_RESOLVE_PROTOCOL_COMPLETE: + result = DoResolveProtocolComplete(result); + break; case STATE_TCP_CONNECT: result = DoTcpConnect(result); break; @@ -452,6 +462,8 @@ void SocketStream::DoLoop(int result) { Finish(result); return; } + if (state == STATE_RESOLVE_PROTOCOL && result == ERR_PROTOCOL_SWITCHED) + continue; // If the connection is not established yet and had actual errors, // close the connection. if (state != STATE_READ_WRITE && result < ERR_IO_PENDING) { @@ -545,28 +557,40 @@ int SocketStream::DoResolveHost() { } int SocketStream::DoResolveHostComplete(int result) { - if (result == OK && delegate_) { - result = delegate_->OnStartOpenConnection(this, &io_callback_); - if (result == ERR_PROTOCOL_SWITCHED) { - next_state_ = STATE_CLOSE; - metrics_->OnCountWireProtocolType( - SocketStreamMetrics::WIRE_PROTOCOL_SPDY); - } else { - next_state_ = STATE_TCP_CONNECT; - metrics_->OnCountWireProtocolType( - SocketStreamMetrics::WIRE_PROTOCOL_WEBSOCKET); - if (result == ERR_IO_PENDING) - metrics_->OnWaitConnection(); - } - } else { + if (result == OK && delegate_) + next_state_ = STATE_RESOLVE_PROTOCOL; + else next_state_ = STATE_CLOSE; - } // TODO(ukai): if error occured, reconsider proxy after error. return result; } -const ProxyServer& SocketStream::proxy_server() const { - return proxy_info_.proxy_server(); +int SocketStream::DoResolveProtocol(int result) { + DCHECK_EQ(OK, result); + next_state_ = STATE_RESOLVE_PROTOCOL_COMPLETE; + result = delegate_->OnStartOpenConnection(this, &io_callback_); + if (result == ERR_IO_PENDING) + metrics_->OnWaitConnection(); + else if (result != OK && result != ERR_PROTOCOL_SWITCHED) + next_state_ = STATE_CLOSE; + return result; +} + +int SocketStream::DoResolveProtocolComplete(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + + if (result == ERR_PROTOCOL_SWITCHED) { + next_state_ = STATE_CLOSE; + metrics_->OnCountWireProtocolType( + SocketStreamMetrics::WIRE_PROTOCOL_SPDY); + } else if (result == OK) { + next_state_ = STATE_TCP_CONNECT; + metrics_->OnCountWireProtocolType( + SocketStreamMetrics::WIRE_PROTOCOL_WEBSOCKET); + } else { + next_state_ = STATE_CLOSE; + } + return result; } int SocketStream::DoTcpConnect(int result) { diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h index 4fb9ea3..3cbea0f 100644 --- a/net/socket_stream/socket_stream.h +++ b/net/socket_stream/socket_stream.h @@ -175,6 +175,9 @@ class NET_API SocketStream : public base::RefCountedThreadSafe<SocketStream> { Delegate* delegate_; private: + FRIEND_TEST_ALL_PREFIXES(SocketStreamTest, IOPending); + FRIEND_TEST_ALL_PREFIXES(SocketStreamTest, SwitchAfterPending); + friend class WebSocketThrottleTest; typedef std::map<const void*, linked_ptr<UserData> > UserDataMap; @@ -215,6 +218,8 @@ class NET_API SocketStream : public base::RefCountedThreadSafe<SocketStream> { STATE_RESOLVE_PROXY_COMPLETE, STATE_RESOLVE_HOST, STATE_RESOLVE_HOST_COMPLETE, + STATE_RESOLVE_PROTOCOL, + STATE_RESOLVE_PROTOCOL_COMPLETE, STATE_TCP_CONNECT, STATE_TCP_CONNECT_COMPLETE, STATE_WRITE_TUNNEL_HEADERS, @@ -264,6 +269,8 @@ class NET_API SocketStream : public base::RefCountedThreadSafe<SocketStream> { int DoResolveProxyComplete(int result); int DoResolveHost(); int DoResolveHostComplete(int result); + int DoResolveProtocol(int result); + int DoResolveProtocolComplete(int result); int DoTcpConnect(int result); int DoTcpConnectComplete(int result); int DoWriteTunnelHeaders(); diff --git a/net/socket_stream/socket_stream_unittest.cc b/net/socket_stream/socket_stream_unittest.cc index 7a4bb4a..a1ea8ac 100644 --- a/net/socket_stream/socket_stream_unittest.cc +++ b/net/socket_stream/socket_stream_unittest.cc @@ -75,6 +75,7 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate { virtual int OnStartOpenConnection(net::SocketStream* socket, net::CompletionCallback* callback) { + connection_callback_ = callback; events_.push_back( SocketStreamEvent(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, socket, 0, std::string(), NULL, callback)); @@ -138,6 +139,9 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate { username_ = username; password_ = password; } + void CompleteConnection(int result) { + connection_callback_->Run(result); + } const std::vector<SocketStreamEvent>& GetSeenEvents() const { return events_; @@ -152,6 +156,7 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate { base::Callback<void(SocketStreamEvent*)> on_close_; base::Callback<void(SocketStreamEvent*)> on_auth_required_; net::CompletionCallback* callback_; + net::CompletionCallback* connection_callback_; string16 username_; string16 password_; @@ -210,9 +215,17 @@ class SocketStreamTest : public PlatformTest { return net::ERR_PROTOCOL_SWITCHED; } + virtual int DoIOPending(SocketStreamEvent* event) { + io_callback_.Run(net::OK); + return net::ERR_IO_PENDING; + } + static const char kWebSocketHandshakeRequest[]; static const char kWebSocketHandshakeResponse[]; + protected: + TestCompletionCallback io_callback_; + private: std::string handshake_request_; std::string handshake_response_; @@ -373,6 +386,72 @@ TEST_F(SocketStreamTest, BasicAuthProxy) { // TODO(eroman): Add back NetLogTest here... } +TEST_F(SocketStreamTest, IOPending) { + TestCompletionCallback callback; + + scoped_ptr<SocketStreamEventRecorder> delegate( + new SocketStreamEventRecorder(&callback)); + delegate->SetOnConnected(base::Bind( + &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this))); + delegate->SetOnReceivedData(base::Bind( + &SocketStreamTest::DoCloseFlushPendingWriteTest, + base::Unretained(this))); + delegate->SetOnStartOpenConnection(base::Bind( + &SocketStreamTest::DoIOPending, base::Unretained(this))); + + MockHostResolver host_resolver; + + scoped_refptr<SocketStream> socket_stream( + new SocketStream(GURL("ws://example.com/demo"), delegate.get())); + + socket_stream->set_context(new TestURLRequestContext()); + socket_stream->SetHostResolver(&host_resolver); + + MockWrite data_writes[] = { + MockWrite(SocketStreamTest::kWebSocketHandshakeRequest), + MockWrite(true, "\0message1\xff", 10), + MockWrite(true, "\0message2\xff", 10) + }; + MockRead data_reads[] = { + MockRead(SocketStreamTest::kWebSocketHandshakeResponse), + // Server doesn't close the connection after handshake. + MockRead(true, ERR_IO_PENDING) + }; + AddWebSocketMessage("message1"); + AddWebSocketMessage("message2"); + + scoped_refptr<DelayedSocketData> data_provider( + new DelayedSocketData(1, + data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes))); + + MockClientSocketFactory* mock_socket_factory = + GetMockClientSocketFactory(); + mock_socket_factory->AddSocketDataProvider(data_provider.get()); + + socket_stream->SetClientSocketFactory(mock_socket_factory); + + socket_stream->Connect(); + io_callback_.WaitForResult(); + EXPECT_EQ(net::SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE, + socket_stream->next_state_); + delegate->CompleteConnection(net::OK); + + EXPECT_EQ(net::OK, callback.WaitForResult()); + + const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); + ASSERT_EQ(7U, events.size()); + + EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, + events[0].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[5].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[6].event_type); +} + TEST_F(SocketStreamTest, SwitchToSpdy) { TestCompletionCallback callback; @@ -391,6 +470,38 @@ TEST_F(SocketStreamTest, SwitchToSpdy) { socket_stream->Connect(); + EXPECT_EQ(net::OK, callback.WaitForResult()); + + const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); + ASSERT_EQ(2U, events.size()); + + EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, + events[0].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[1].event_type); +} + +TEST_F(SocketStreamTest, SwitchAfterPending) { + TestCompletionCallback callback; + + scoped_ptr<SocketStreamEventRecorder> delegate( + new SocketStreamEventRecorder(&callback)); + delegate->SetOnStartOpenConnection(base::Bind( + &SocketStreamTest::DoIOPending, base::Unretained(this))); + + MockHostResolver host_resolver; + + scoped_refptr<SocketStream> socket_stream( + new SocketStream(GURL("ws://example.com/demo"), delegate.get())); + + socket_stream->set_context(new TestURLRequestContext()); + socket_stream->SetHostResolver(&host_resolver); + + socket_stream->Connect(); + io_callback_.WaitForResult(); + EXPECT_EQ(net::SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE, + socket_stream->next_state_); + delegate->CompleteConnection(net::ERR_PROTOCOL_SWITCHED); + int result = callback.WaitForResult(); const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); |