diff options
-rw-r--r-- | net/socket_stream/socket_stream.cc | 16 | ||||
-rw-r--r-- | net/socket_stream/socket_stream.h | 2 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_unittest.cc | 132 |
3 files changed, 147 insertions, 3 deletions
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index b2f7a2c..7c1369c 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -59,6 +59,7 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate) current_write_buf_(NULL), write_buf_offset_(0), write_buf_size_(0), + closing_(false), metrics_(new SocketStreamMetrics(url)) { DCHECK(MessageLoop::current()) << "The current MessageLoop must exist"; @@ -175,9 +176,7 @@ void SocketStream::Close() { // of AddRef() and Release() in Connect() and Finish(), respectively. if (next_state_ == STATE_NONE) return; - if (socket_.get() && socket_->IsConnected()) - socket_->Disconnect(); - next_state_ = STATE_CLOSE; + closing_ = true; // Close asynchronously, so that delegate won't be called // back before returning Close(). MessageLoop::current()->PostTask( @@ -215,6 +214,8 @@ void SocketStream::DetachDelegate() { return; delegate_ = NULL; net_log_.AddEvent(NetLog::TYPE_CANCELLED, NULL); + // We don't need to send pending data when client detach the delegate. + pending_write_bufs_.clear(); Close(); } @@ -809,6 +810,15 @@ int SocketStream::DoReadWrite(int result) { return ERR_CONNECTION_CLOSED; } + // If client has requested close(), and there's nothing to write, then + // let's close the socket. + // We don't care about receiving data after the socket is closed. + if (closing_ && !write_buf_ && pending_write_bufs_.empty()) { + socket_->Disconnect(); + next_state_ = STATE_CLOSE; + return OK; + } + next_state_ = STATE_READ_WRITE; if (!read_buf_) { diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h index b8fd0bf..b56d08e 100644 --- a/net/socket_stream/socket_stream.h +++ b/net/socket_stream/socket_stream.h @@ -318,6 +318,8 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { int write_buf_size_; PendingDataQueue pending_write_bufs_; + bool closing_; + scoped_ptr<SocketStreamMetrics> metrics_; DISALLOW_COPY_AND_ASSIGN(SocketStream); diff --git a/net/socket_stream/socket_stream_unittest.cc b/net/socket_stream/socket_stream_unittest.cc index b3bdbbe..d32654a 100644 --- a/net/socket_stream/socket_stream_unittest.cc +++ b/net/socket_stream/socket_stream_unittest.cc @@ -147,8 +147,140 @@ class SocketStreamEventRecorder : public net::SocketStream::Delegate { namespace net { class SocketStreamTest : public PlatformTest { + public: + virtual ~SocketStreamTest() {} + virtual void SetUp() { + mock_socket_factory_.reset(); + handshake_request_ = kWebSocketHandshakeRequest; + handshake_response_ = kWebSocketHandshakeResponse; + } + virtual void TearDown() { + mock_socket_factory_.reset(); + } + + virtual void SetWebSocketHandshakeMessage( + const char* request, const char* response) { + handshake_request_ = request; + handshake_response_ = response; + } + virtual void AddWebSocketMessage(const std::string& message) { + messages_.push_back(message); + } + + virtual MockClientSocketFactory* GetMockClientSocketFactory() { + mock_socket_factory_.reset(new MockClientSocketFactory); + return mock_socket_factory_.get(); + } + + virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) { + event->socket->SendData( + handshake_request_.data(), handshake_request_.size()); + } + + virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) { + // handshake response received. + for (size_t i = 0; i < messages_.size(); i++) { + std::vector<char> frame; + frame.push_back('\0'); + frame.insert(frame.end(), messages_[i].begin(), messages_[i].end()); + frame.push_back('\xff'); + EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size())); + } + // Actual ClientSocket close must happen after all frames queued by + // SendData above are sent out. + event->socket->Close(); + } + + static const char* kWebSocketHandshakeRequest; + static const char* kWebSocketHandshakeResponse; + + private: + std::string handshake_request_; + std::string handshake_response_; + std::vector<std::string> messages_; + + scoped_ptr<MockClientSocketFactory> mock_socket_factory_; }; +const char* SocketStreamTest::kWebSocketHandshakeRequest = + "GET /demo HTTP/1.1\r\n" + "Host: example.com\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" + "Sec-WebSocket-Protocol: sample\r\n" + "Upgrade: WebSocket\r\n" + "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" + "Origin: http://example.com\r\n" + "\r\n" + "^n:ds[4U"; + +const char* SocketStreamTest::kWebSocketHandshakeResponse = + "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Origin: http://example.com\r\n" + "Sec-WebSocket-Location: ws://example.com/demo\r\n" + "Sec-WebSocket-Protocol: sample\r\n" + "\r\n" + "8jKS'y:G*Co,Wxa-"; + +TEST_F(SocketStreamTest, CloseFlushPendingWrite) { + TestCompletionCallback callback; + + scoped_ptr<SocketStreamEventRecorder> delegate( + new SocketStreamEventRecorder(&callback)); + // Necessary for NewCallback. + SocketStreamTest* test = this; + delegate->SetOnConnected(NewCallback( + test, &SocketStreamTest::DoSendWebSocketHandshake)); + delegate->SetOnReceivedData(NewCallback( + test, &SocketStreamTest::DoCloseFlushPendingWriteTest)); + + scoped_refptr<SocketStream> socket_stream = + new SocketStream(GURL("ws://example.com/demo"), delegate.get()); + + socket_stream->set_context(new TestURLRequestContext()); + socket_stream->SetHostResolver(new MockHostResolver()); + + MockWrite data_writes[] = { + MockWrite(SocketStreamTest::kWebSocketHandshakeRequest), + MockWrite(true, "\0message1\xff", 10), + MockWrite(true, "\0message2\xff", 10) + }; + MockRead data_reads[] = { + MockRead(SocketStreamTest::kWebSocketHandshakeResponse), + // Server doesn't close the connection after handshake. + MockRead(true, ERR_IO_PENDING) + }; + AddWebSocketMessage("message1"); + AddWebSocketMessage("message2"); + + scoped_refptr<DelayedSocketData> data_provider( + new DelayedSocketData(1, + data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes))); + + MockClientSocketFactory* mock_socket_factory = + GetMockClientSocketFactory(); + mock_socket_factory->AddSocketDataProvider(data_provider.get()); + + socket_stream->SetClientSocketFactory(mock_socket_factory); + + socket_stream->Connect(); + + callback.WaitForResult(); + + const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); + EXPECT_EQ(6U, events.size()); + + EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type); +} + TEST_F(SocketStreamTest, BasicAuthProxy) { MockClientSocketFactory mock_socket_factory; MockWrite data_writes1[] = { |