summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/socket_stream/socket_stream.cc16
-rw-r--r--net/socket_stream/socket_stream.h2
-rw-r--r--net/socket_stream/socket_stream_unittest.cc132
3 files changed, 147 insertions, 3 deletions
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc
index b2f7a2c..7c1369c 100644
--- a/net/socket_stream/socket_stream.cc
+++ b/net/socket_stream/socket_stream.cc
@@ -59,6 +59,7 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate)
current_write_buf_(NULL),
write_buf_offset_(0),
write_buf_size_(0),
+ closing_(false),
metrics_(new SocketStreamMetrics(url)) {
DCHECK(MessageLoop::current()) <<
"The current MessageLoop must exist";
@@ -175,9 +176,7 @@ void SocketStream::Close() {
// of AddRef() and Release() in Connect() and Finish(), respectively.
if (next_state_ == STATE_NONE)
return;
- if (socket_.get() && socket_->IsConnected())
- socket_->Disconnect();
- next_state_ = STATE_CLOSE;
+ closing_ = true;
// Close asynchronously, so that delegate won't be called
// back before returning Close().
MessageLoop::current()->PostTask(
@@ -215,6 +214,8 @@ void SocketStream::DetachDelegate() {
return;
delegate_ = NULL;
net_log_.AddEvent(NetLog::TYPE_CANCELLED, NULL);
+ // We don't need to send pending data when client detach the delegate.
+ pending_write_bufs_.clear();
Close();
}
@@ -809,6 +810,15 @@ int SocketStream::DoReadWrite(int result) {
return ERR_CONNECTION_CLOSED;
}
+ // If client has requested close(), and there's nothing to write, then
+ // let's close the socket.
+ // We don't care about receiving data after the socket is closed.
+ if (closing_ && !write_buf_ && pending_write_bufs_.empty()) {
+ socket_->Disconnect();
+ next_state_ = STATE_CLOSE;
+ return OK;
+ }
+
next_state_ = STATE_READ_WRITE;
if (!read_buf_) {
diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h
index b8fd0bf..b56d08e 100644
--- a/net/socket_stream/socket_stream.h
+++ b/net/socket_stream/socket_stream.h
@@ -318,6 +318,8 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
int write_buf_size_;
PendingDataQueue pending_write_bufs_;
+ bool closing_;
+
scoped_ptr<SocketStreamMetrics> metrics_;
DISALLOW_COPY_AND_ASSIGN(SocketStream);
diff --git a/net/socket_stream/socket_stream_unittest.cc b/net/socket_stream/socket_stream_unittest.cc
index b3bdbbe..d32654a 100644
--- a/net/socket_stream/socket_stream_unittest.cc
+++ b/net/socket_stream/socket_stream_unittest.cc
@@ -147,8 +147,140 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate {
namespace net {
class SocketStreamTest : public PlatformTest {
+ public:
+ virtual ~SocketStreamTest() {}
+ virtual void SetUp() {
+ mock_socket_factory_.reset();
+ handshake_request_ = kWebSocketHandshakeRequest;
+ handshake_response_ = kWebSocketHandshakeResponse;
+ }
+ virtual void TearDown() {
+ mock_socket_factory_.reset();
+ }
+
+ virtual void SetWebSocketHandshakeMessage(
+ const char* request, const char* response) {
+ handshake_request_ = request;
+ handshake_response_ = response;
+ }
+ virtual void AddWebSocketMessage(const std::string& message) {
+ messages_.push_back(message);
+ }
+
+ virtual MockClientSocketFactory* GetMockClientSocketFactory() {
+ mock_socket_factory_.reset(new MockClientSocketFactory);
+ return mock_socket_factory_.get();
+ }
+
+ virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) {
+ event->socket->SendData(
+ handshake_request_.data(), handshake_request_.size());
+ }
+
+ virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) {
+ // handshake response received.
+ for (size_t i = 0; i < messages_.size(); i++) {
+ std::vector<char> frame;
+ frame.push_back('\0');
+ frame.insert(frame.end(), messages_[i].begin(), messages_[i].end());
+ frame.push_back('\xff');
+ EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size()));
+ }
+ // Actual ClientSocket close must happen after all frames queued by
+ // SendData above are sent out.
+ event->socket->Close();
+ }
+
+ static const char* kWebSocketHandshakeRequest;
+ static const char* kWebSocketHandshakeResponse;
+
+ private:
+ std::string handshake_request_;
+ std::string handshake_response_;
+ std::vector<std::string> messages_;
+
+ scoped_ptr<MockClientSocketFactory> mock_socket_factory_;
};
+const char* SocketStreamTest::kWebSocketHandshakeRequest =
+ "GET /demo HTTP/1.1\r\n"
+ "Host: example.com\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n"
+ "Sec-WebSocket-Protocol: sample\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n"
+ "Origin: http://example.com\r\n"
+ "\r\n"
+ "^n:ds[4U";
+
+const char* SocketStreamTest::kWebSocketHandshakeResponse =
+ "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Origin: http://example.com\r\n"
+ "Sec-WebSocket-Location: ws://example.com/demo\r\n"
+ "Sec-WebSocket-Protocol: sample\r\n"
+ "\r\n"
+ "8jKS'y:G*Co,Wxa-";
+
+TEST_F(SocketStreamTest, CloseFlushPendingWrite) {
+ TestCompletionCallback callback;
+
+ scoped_ptr<SocketStreamEventRecorder> delegate(
+ new SocketStreamEventRecorder(&callback));
+ // Necessary for NewCallback.
+ SocketStreamTest* test = this;
+ delegate->SetOnConnected(NewCallback(
+ test, &SocketStreamTest::DoSendWebSocketHandshake));
+ delegate->SetOnReceivedData(NewCallback(
+ test, &SocketStreamTest::DoCloseFlushPendingWriteTest));
+
+ scoped_refptr<SocketStream> socket_stream =
+ new SocketStream(GURL("ws://example.com/demo"), delegate.get());
+
+ socket_stream->set_context(new TestURLRequestContext());
+ socket_stream->SetHostResolver(new MockHostResolver());
+
+ MockWrite data_writes[] = {
+ MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
+ MockWrite(true, "\0message1\xff", 10),
+ MockWrite(true, "\0message2\xff", 10)
+ };
+ MockRead data_reads[] = {
+ MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
+ // Server doesn't close the connection after handshake.
+ MockRead(true, ERR_IO_PENDING)
+ };
+ AddWebSocketMessage("message1");
+ AddWebSocketMessage("message2");
+
+ scoped_refptr<DelayedSocketData> data_provider(
+ new DelayedSocketData(1,
+ data_reads, arraysize(data_reads),
+ data_writes, arraysize(data_writes)));
+
+ MockClientSocketFactory* mock_socket_factory =
+ GetMockClientSocketFactory();
+ mock_socket_factory->AddSocketDataProvider(data_provider.get());
+
+ socket_stream->SetClientSocketFactory(mock_socket_factory);
+
+ socket_stream->Connect();
+
+ callback.WaitForResult();
+
+ const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
+ EXPECT_EQ(6U, events.size());
+
+ EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type);
+}
+
TEST_F(SocketStreamTest, BasicAuthProxy) {
MockClientSocketFactory mock_socket_factory;
MockWrite data_writes1[] = {