summaryrefslogtreecommitdiffstats
path: root/net/socket_stream
diff options
context:
space:
mode:
Diffstat (limited to 'net/socket_stream')
-rw-r--r--net/socket_stream/socket_stream.cc58
-rw-r--r--net/socket_stream/socket_stream.h7
-rw-r--r--net/socket_stream/socket_stream_unittest.cc111
3 files changed, 159 insertions, 17 deletions
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc
index b49a6a1..4c5b13a 100644
--- a/net/socket_stream/socket_stream.cc
+++ b/net/socket_stream/socket_stream.cc
@@ -234,6 +234,10 @@ void SocketStream::DetachDelegate() {
Close();
}
+const ProxyServer& SocketStream::proxy_server() const {
+ return proxy_info_.proxy_server();
+}
+
void SocketStream::SetHostResolver(HostResolver* host_resolver) {
DCHECK(host_resolver);
host_resolver_ = host_resolver;
@@ -401,6 +405,12 @@ void SocketStream::DoLoop(int result) {
case STATE_RESOLVE_HOST_COMPLETE:
result = DoResolveHostComplete(result);
break;
+ case STATE_RESOLVE_PROTOCOL:
+ result = DoResolveProtocol(result);
+ break;
+ case STATE_RESOLVE_PROTOCOL_COMPLETE:
+ result = DoResolveProtocolComplete(result);
+ break;
case STATE_TCP_CONNECT:
result = DoTcpConnect(result);
break;
@@ -452,6 +462,8 @@ void SocketStream::DoLoop(int result) {
Finish(result);
return;
}
+ if (state == STATE_RESOLVE_PROTOCOL && result == ERR_PROTOCOL_SWITCHED)
+ continue;
// If the connection is not established yet and had actual errors,
// close the connection.
if (state != STATE_READ_WRITE && result < ERR_IO_PENDING) {
@@ -545,28 +557,40 @@ int SocketStream::DoResolveHost() {
}
int SocketStream::DoResolveHostComplete(int result) {
- if (result == OK && delegate_) {
- result = delegate_->OnStartOpenConnection(this, &io_callback_);
- if (result == ERR_PROTOCOL_SWITCHED) {
- next_state_ = STATE_CLOSE;
- metrics_->OnCountWireProtocolType(
- SocketStreamMetrics::WIRE_PROTOCOL_SPDY);
- } else {
- next_state_ = STATE_TCP_CONNECT;
- metrics_->OnCountWireProtocolType(
- SocketStreamMetrics::WIRE_PROTOCOL_WEBSOCKET);
- if (result == ERR_IO_PENDING)
- metrics_->OnWaitConnection();
- }
- } else {
+ if (result == OK && delegate_)
+ next_state_ = STATE_RESOLVE_PROTOCOL;
+ else
next_state_ = STATE_CLOSE;
- }
// TODO(ukai): if error occured, reconsider proxy after error.
return result;
}
-const ProxyServer& SocketStream::proxy_server() const {
- return proxy_info_.proxy_server();
+int SocketStream::DoResolveProtocol(int result) {
+ DCHECK_EQ(OK, result);
+ next_state_ = STATE_RESOLVE_PROTOCOL_COMPLETE;
+ result = delegate_->OnStartOpenConnection(this, &io_callback_);
+ if (result == ERR_IO_PENDING)
+ metrics_->OnWaitConnection();
+ else if (result != OK && result != ERR_PROTOCOL_SWITCHED)
+ next_state_ = STATE_CLOSE;
+ return result;
+}
+
+int SocketStream::DoResolveProtocolComplete(int result) {
+ DCHECK_NE(ERR_IO_PENDING, result);
+
+ if (result == ERR_PROTOCOL_SWITCHED) {
+ next_state_ = STATE_CLOSE;
+ metrics_->OnCountWireProtocolType(
+ SocketStreamMetrics::WIRE_PROTOCOL_SPDY);
+ } else if (result == OK) {
+ next_state_ = STATE_TCP_CONNECT;
+ metrics_->OnCountWireProtocolType(
+ SocketStreamMetrics::WIRE_PROTOCOL_WEBSOCKET);
+ } else {
+ next_state_ = STATE_CLOSE;
+ }
+ return result;
}
int SocketStream::DoTcpConnect(int result) {
diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h
index 4fb9ea3..3cbea0f 100644
--- a/net/socket_stream/socket_stream.h
+++ b/net/socket_stream/socket_stream.h
@@ -175,6 +175,9 @@ class NET_API SocketStream : public base::RefCountedThreadSafe<SocketStream> {
Delegate* delegate_;
private:
+ FRIEND_TEST_ALL_PREFIXES(SocketStreamTest, IOPending);
+ FRIEND_TEST_ALL_PREFIXES(SocketStreamTest, SwitchAfterPending);
+
friend class WebSocketThrottleTest;
typedef std::map<const void*, linked_ptr<UserData> > UserDataMap;
@@ -215,6 +218,8 @@ class NET_API SocketStream : public base::RefCountedThreadSafe<SocketStream> {
STATE_RESOLVE_PROXY_COMPLETE,
STATE_RESOLVE_HOST,
STATE_RESOLVE_HOST_COMPLETE,
+ STATE_RESOLVE_PROTOCOL,
+ STATE_RESOLVE_PROTOCOL_COMPLETE,
STATE_TCP_CONNECT,
STATE_TCP_CONNECT_COMPLETE,
STATE_WRITE_TUNNEL_HEADERS,
@@ -264,6 +269,8 @@ class NET_API SocketStream : public base::RefCountedThreadSafe<SocketStream> {
int DoResolveProxyComplete(int result);
int DoResolveHost();
int DoResolveHostComplete(int result);
+ int DoResolveProtocol(int result);
+ int DoResolveProtocolComplete(int result);
int DoTcpConnect(int result);
int DoTcpConnectComplete(int result);
int DoWriteTunnelHeaders();
diff --git a/net/socket_stream/socket_stream_unittest.cc b/net/socket_stream/socket_stream_unittest.cc
index 7a4bb4a..a1ea8ac 100644
--- a/net/socket_stream/socket_stream_unittest.cc
+++ b/net/socket_stream/socket_stream_unittest.cc
@@ -75,6 +75,7 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate {
virtual int OnStartOpenConnection(net::SocketStream* socket,
net::CompletionCallback* callback) {
+ connection_callback_ = callback;
events_.push_back(
SocketStreamEvent(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
socket, 0, std::string(), NULL, callback));
@@ -138,6 +139,9 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate {
username_ = username;
password_ = password;
}
+ void CompleteConnection(int result) {
+ connection_callback_->Run(result);
+ }
const std::vector<SocketStreamEvent>& GetSeenEvents() const {
return events_;
@@ -152,6 +156,7 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate {
base::Callback<void(SocketStreamEvent*)> on_close_;
base::Callback<void(SocketStreamEvent*)> on_auth_required_;
net::CompletionCallback* callback_;
+ net::CompletionCallback* connection_callback_;
string16 username_;
string16 password_;
@@ -210,9 +215,17 @@ class SocketStreamTest : public PlatformTest {
return net::ERR_PROTOCOL_SWITCHED;
}
+ virtual int DoIOPending(SocketStreamEvent* event) {
+ io_callback_.Run(net::OK);
+ return net::ERR_IO_PENDING;
+ }
+
static const char kWebSocketHandshakeRequest[];
static const char kWebSocketHandshakeResponse[];
+ protected:
+ TestCompletionCallback io_callback_;
+
private:
std::string handshake_request_;
std::string handshake_response_;
@@ -373,6 +386,72 @@ TEST_F(SocketStreamTest, BasicAuthProxy) {
// TODO(eroman): Add back NetLogTest here...
}
+TEST_F(SocketStreamTest, IOPending) {
+ TestCompletionCallback callback;
+
+ scoped_ptr<SocketStreamEventRecorder> delegate(
+ new SocketStreamEventRecorder(&callback));
+ delegate->SetOnConnected(base::Bind(
+ &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this)));
+ delegate->SetOnReceivedData(base::Bind(
+ &SocketStreamTest::DoCloseFlushPendingWriteTest,
+ base::Unretained(this)));
+ delegate->SetOnStartOpenConnection(base::Bind(
+ &SocketStreamTest::DoIOPending, base::Unretained(this)));
+
+ MockHostResolver host_resolver;
+
+ scoped_refptr<SocketStream> socket_stream(
+ new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
+
+ socket_stream->set_context(new TestURLRequestContext());
+ socket_stream->SetHostResolver(&host_resolver);
+
+ MockWrite data_writes[] = {
+ MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
+ MockWrite(true, "\0message1\xff", 10),
+ MockWrite(true, "\0message2\xff", 10)
+ };
+ MockRead data_reads[] = {
+ MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
+ // Server doesn't close the connection after handshake.
+ MockRead(true, ERR_IO_PENDING)
+ };
+ AddWebSocketMessage("message1");
+ AddWebSocketMessage("message2");
+
+ scoped_refptr<DelayedSocketData> data_provider(
+ new DelayedSocketData(1,
+ data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes)));
+
+ MockClientSocketFactory* mock_socket_factory =
+ GetMockClientSocketFactory();
+ mock_socket_factory->AddSocketDataProvider(data_provider.get());
+
+ socket_stream->SetClientSocketFactory(mock_socket_factory);
+
+ socket_stream->Connect();
+ io_callback_.WaitForResult();
+ EXPECT_EQ(net::SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE,
+ socket_stream->next_state_);
+ delegate->CompleteConnection(net::OK);
+
+ EXPECT_EQ(net::OK, callback.WaitForResult());
+
+ const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
+ ASSERT_EQ(7U, events.size());
+
+ EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
+ events[0].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[5].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[6].event_type);
+}
+
TEST_F(SocketStreamTest, SwitchToSpdy) {
TestCompletionCallback callback;
@@ -391,6 +470,38 @@ TEST_F(SocketStreamTest, SwitchToSpdy) {
socket_stream->Connect();
+ EXPECT_EQ(net::OK, callback.WaitForResult());
+
+ const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
+ ASSERT_EQ(2U, events.size());
+
+ EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
+ events[0].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[1].event_type);
+}
+
+TEST_F(SocketStreamTest, SwitchAfterPending) {
+ TestCompletionCallback callback;
+
+ scoped_ptr<SocketStreamEventRecorder> delegate(
+ new SocketStreamEventRecorder(&callback));
+ delegate->SetOnStartOpenConnection(base::Bind(
+ &SocketStreamTest::DoIOPending, base::Unretained(this)));
+
+ MockHostResolver host_resolver;
+
+ scoped_refptr<SocketStream> socket_stream(
+ new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
+
+ socket_stream->set_context(new TestURLRequestContext());
+ socket_stream->SetHostResolver(&host_resolver);
+
+ socket_stream->Connect();
+ io_callback_.WaitForResult();
+ EXPECT_EQ(net::SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE,
+ socket_stream->next_state_);
+ delegate->CompleteConnection(net::ERR_PROTOCOL_SWITCHED);
+
int result = callback.WaitForResult();
const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();