diff options
author | willchan@chromium.org <willchan@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-04-17 23:39:37 +0000 |
---|---|---|
committer | willchan@chromium.org <willchan@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-04-17 23:39:37 +0000 |
commit | 340ba394b3b34db9c805418c5e54f37e3b264bd2 (patch) | |
tree | 326844a05569dcfe82625ae63d988e1da630074e | |
parent | f4dbdd6824257dcb211c47d4a950cc955af60953 (diff) | |
download | chromium_src-340ba394b3b34db9c805418c5e54f37e3b264bd2.zip chromium_src-340ba394b3b34db9c805418c5e54f37e3b264bd2.tar.gz chromium_src-340ba394b3b34db9c805418c5e54f37e3b264bd2.tar.bz2 |
Implement full duplex mode for windows tcp sockets.
Move tcp_client_socket.h stuff to tcp_client_socket_libevent.h and
tcp_client_socket_win.h.
Add tests.
Review URL: http://codereview.chromium.org/75030
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@13983 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r-- | net/base/tcp_client_socket.h | 122 | ||||
-rw-r--r-- | net/base/tcp_client_socket_libevent.cc | 105 | ||||
-rw-r--r-- | net/base/tcp_client_socket_libevent.h | 92 | ||||
-rw-r--r-- | net/base/tcp_client_socket_unittest.cc | 152 | ||||
-rw-r--r-- | net/base/tcp_client_socket_win.cc | 243 | ||||
-rw-r--r-- | net/base/tcp_client_socket_win.h | 114 | ||||
-rw-r--r-- | net/net.gyp | 2 |
7 files changed, 510 insertions, 320 deletions
diff --git a/net/base/tcp_client_socket.h b/net/base/tcp_client_socket.h index 729bbbc..820e586 100644 --- a/net/base/tcp_client_socket.h +++ b/net/base/tcp_client_socket.h @@ -1,4 +1,4 @@ -// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved. +// Copyright (c) 2006-2009 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -8,132 +8,20 @@ #include "build/build_config.h" #if defined(OS_WIN) -#include <ws2tcpip.h> -#include "base/object_watcher.h" +#include "net/base/tcp_client_socket_win.h" #elif defined(OS_POSIX) -struct event; // From libevent -#include <sys/socket.h> // for struct sockaddr -#define SOCKET int -#include "base/message_loop.h" +#include "net/base/tcp_client_socket_libevent.h" #endif -#include "net/base/address_list.h" -#include "net/base/client_socket.h" -#include "net/base/completion_callback.h" - namespace net { // A client socket that uses TCP as the transport layer. -// -// NOTE: The windows implementation supports half duplex only. -// Read and Write calls must not be in progress at the same time. -// The libevent implementation supports full duplex because that -// made it slightly easier to implement ssl. -class TCPClientSocket : public ClientSocket, -#if defined(OS_WIN) - public base::ObjectWatcher::Delegate -#elif defined(OS_POSIX) - public MessageLoopForIO::Watcher -#endif -{ - public: - // The IP address(es) and port number to connect to. The TCP socket will try - // each IP address in the list until it succeeds in establishing a - // connection. - explicit TCPClientSocket(const AddressList& addresses); - - ~TCPClientSocket(); - - // ClientSocket methods: - virtual int Connect(CompletionCallback* callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const; - - // Socket methods: - // Multiple outstanding requests are not supported. - // Full duplex mode (reading and writing at the same time) is not supported - // on Windows (but is supported on Linux and Mac for ease of implementation - // of SSLClientSocket) - virtual int Read(char* buf, int buf_len, CompletionCallback* callback); - virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); - -#if defined(OS_POSIX) - // Identical to posix system call of same name - // Needed by ssl_client_socket_nss - virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); -#endif - - private: - SOCKET socket_; - - // The list of addresses we should try in order to establish a connection. - AddressList addresses_; - - // Where we are in above list, or NULL if all addrinfos have been tried. - const struct addrinfo* current_ai_; - #if defined(OS_WIN) - enum WaitState { - NOT_WAITING, - WAITING_CONNECT, - WAITING_READ, - WAITING_WRITE - }; - WaitState wait_state_; - - // base::ObjectWatcher::Delegate methods: - virtual void OnObjectSignaled(HANDLE object); - - // Waits for the (manual-reset) event object to become signaled and resets - // it. Called after a Winsock function succeeds synchronously - // - // Our testing shows that except in rare cases (when running inside QEMU), - // the event object is already signaled at this point, so we just call this - // method on the IO thread to avoid a context switch. - void WaitForAndResetEvent(); - - OVERLAPPED overlapped_; - WSABUF buffer_; - - base::ObjectWatcher watcher_; - - void DidCompleteIO(); +typedef TCPClientSocketWin TCPClientSocket; #elif defined(OS_POSIX) - // Whether we're currently waiting for connect() to complete - bool waiting_connect_; - - // The socket's libevent wrapper - MessageLoopForIO::FileDescriptorWatcher socket_watcher_; - - // Called by MessagePumpLibevent when the socket is ready to do I/O - void OnFileCanReadWithoutBlocking(int fd); - void OnFileCanWriteWithoutBlocking(int fd); - - // The buffer used by OnSocketReady to retry Read requests - char* buf_; - int buf_len_; - - // The buffer used by OnSocketReady to retry Write requests - const char* write_buf_; - int write_buf_len_; - - // External callback; called when write is complete. - CompletionCallback* write_callback_; - - void DoWriteCallback(int rv); - void DidCompleteRead(); - void DidCompleteWrite(); +typedef TCPClientSocketLibevent TCPClientSocket; #endif - // External callback; called when read (and on Windows, write) is complete. - CompletionCallback* callback_; - - int CreateSocket(const struct addrinfo* ai); - void DoCallback(int rv); - void DidCompleteConnect(); -}; - } // namespace net #endif // NET_BASE_TCP_CLIENT_SOCKET_H_ diff --git a/net/base/tcp_client_socket_libevent.cc b/net/base/tcp_client_socket_libevent.cc index 8d51dbf..112536d 100644 --- a/net/base/tcp_client_socket_libevent.cc +++ b/net/base/tcp_client_socket_libevent.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/base/tcp_client_socket.h" +#include "net/base/tcp_client_socket_libevent.h" #include <errno.h> #include <fcntl.h> @@ -65,20 +65,20 @@ static int MapPosixError(int err) { //----------------------------------------------------------------------------- -TCPClientSocket::TCPClientSocket(const AddressList& addresses) - : socket_(kInvalidSocket), - addresses_(addresses), - current_ai_(addresses_.head()), - waiting_connect_(false), - write_callback_(NULL), - callback_(NULL) { +TCPClientSocketLibevent::TCPClientSocketLibevent(const AddressList& addresses) + : socket_(kInvalidSocket), + addresses_(addresses), + current_ai_(addresses_.head()), + waiting_connect_(false), + read_callback_(NULL), + write_callback_(NULL) { } -TCPClientSocket::~TCPClientSocket() { +TCPClientSocketLibevent::~TCPClientSocketLibevent() { Disconnect(); } -int TCPClientSocket::Connect(CompletionCallback* callback) { +int TCPClientSocketLibevent::Connect(CompletionCallback* callback) { // If already connected, then just return OK. if (socket_ != kInvalidSocket) return OK; @@ -122,11 +122,11 @@ int TCPClientSocket::Connect(CompletionCallback* callback) { } waiting_connect_ = true; - callback_ = callback; + read_callback_ = callback; return ERR_IO_PENDING; } -void TCPClientSocket::Disconnect() { +void TCPClientSocketLibevent::Disconnect() { if (socket_ == kInvalidSocket) return; @@ -141,7 +141,7 @@ void TCPClientSocket::Disconnect() { current_ai_ = addresses_.head(); } -bool TCPClientSocket::IsConnected() const { +bool TCPClientSocketLibevent::IsConnected() const { if (socket_ == kInvalidSocket || waiting_connect_) return false; @@ -156,7 +156,7 @@ bool TCPClientSocket::IsConnected() const { return true; } -bool TCPClientSocket::IsConnectedAndIdle() const { +bool TCPClientSocketLibevent::IsConnectedAndIdle() const { if (socket_ == kInvalidSocket || waiting_connect_) return false; @@ -172,12 +172,12 @@ bool TCPClientSocket::IsConnectedAndIdle() const { return true; } -int TCPClientSocket::Read(char* buf, - int buf_len, - CompletionCallback* callback) { - DCHECK(socket_ != kInvalidSocket); +int TCPClientSocketLibevent::Read(char* buf, + int buf_len, + CompletionCallback* callback) { + DCHECK_NE(socket_, kInvalidSocket); DCHECK(!waiting_connect_); - DCHECK(!callback_); + DCHECK(!read_callback_); // Synchronous operation not supported DCHECK(callback); DCHECK(buf_len > 0); @@ -194,27 +194,27 @@ int TCPClientSocket::Read(char* buf, } if (!MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, MessageLoopForIO::WATCH_READ, &socket_watcher_, this)) - { + socket_, true, MessageLoopForIO::WATCH_READ, + &socket_watcher_, this)) { DLOG(INFO) << "WatchFileDescriptor failed on read, errno " << errno; return MapPosixError(errno); } - buf_ = buf; - buf_len_ = buf_len; - callback_ = callback; + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; return ERR_IO_PENDING; } -int TCPClientSocket::Write(const char* buf, - int buf_len, - CompletionCallback* callback) { +int TCPClientSocketLibevent::Write(const char* buf, + int buf_len, + CompletionCallback* callback) { DCHECK(socket_ != kInvalidSocket); DCHECK(!waiting_connect_); DCHECK(!write_callback_); // Synchronous operation not supported DCHECK(callback); - DCHECK(buf_len > 0); + DCHECK_GT(buf_len, 0); TRACE_EVENT_BEGIN("socket.write", this, ""); int nwrite = write(socket_, buf, buf_len); @@ -226,8 +226,8 @@ int TCPClientSocket::Write(const char* buf, return MapPosixError(errno); if (!MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, MessageLoopForIO::WATCH_WRITE, &socket_watcher_, this)) - { + socket_, true, MessageLoopForIO::WATCH_WRITE, + &socket_watcher_, this)) { DLOG(INFO) << "WatchFileDescriptor failed on write, errno " << errno; return MapPosixError(errno); } @@ -239,7 +239,7 @@ int TCPClientSocket::Write(const char* buf, return ERR_IO_PENDING; } -int TCPClientSocket::CreateSocket(const addrinfo* ai) { +int TCPClientSocketLibevent::CreateSocket(const addrinfo* ai) { socket_ = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); if (socket_ == kInvalidSocket) return MapPosixError(errno); @@ -250,18 +250,18 @@ int TCPClientSocket::CreateSocket(const addrinfo* ai) { return OK; } -void TCPClientSocket::DoCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); - DCHECK(callback_); +void TCPClientSocketLibevent::DoReadCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(read_callback_); - // since Run may result in Read being called, clear callback_ up front. - CompletionCallback* c = callback_; - callback_ = NULL; + // since Run may result in Read being called, clear read_callback_ up front. + CompletionCallback* c = read_callback_; + read_callback_ = NULL; c->Run(rv); } -void TCPClientSocket::DoWriteCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); +void TCPClientSocketLibevent::DoWriteCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); DCHECK(write_callback_); // since Run may result in Write being called, clear write_callback_ up front. @@ -270,7 +270,7 @@ void TCPClientSocket::DoWriteCallback(int rv) { c->Run(rv); } -void TCPClientSocket::DidCompleteConnect() { +void TCPClientSocketLibevent::DidCompleteConnect() { int result = ERR_UNEXPECTED; TRACE_EVENT_END("socket.connect", this, ""); @@ -295,7 +295,7 @@ void TCPClientSocket::DidCompleteConnect() { const addrinfo* next = current_ai_->ai_next; Disconnect(); current_ai_ = next; - result = Connect(callback_); + result = Connect(read_callback_); } else { result = MapPosixError(error_code); socket_watcher_.StopWatchingFileDescriptor(); @@ -303,13 +303,13 @@ void TCPClientSocket::DidCompleteConnect() { } if (result != ERR_IO_PENDING) { - DoCallback(result); + DoReadCallback(result); } } -void TCPClientSocket::DidCompleteRead() { +void TCPClientSocketLibevent::DidCompleteRead() { int bytes_transferred; - bytes_transferred = read(socket_, buf_, buf_len_); + bytes_transferred = read(socket_, read_buf_, read_buf_len_); int result; if (bytes_transferred >= 0) { @@ -321,14 +321,14 @@ void TCPClientSocket::DidCompleteRead() { } if (result != ERR_IO_PENDING) { - buf_ = NULL; - buf_len_ = 0; + read_buf_ = NULL; + read_buf_len_ = 0; socket_watcher_.StopWatchingFileDescriptor(); - DoCallback(result); + DoReadCallback(result); } } -void TCPClientSocket::DidCompleteWrite() { +void TCPClientSocketLibevent::DidCompleteWrite() { int bytes_transferred; bytes_transferred = write(socket_, write_buf_, write_buf_len_); @@ -349,15 +349,15 @@ void TCPClientSocket::DidCompleteWrite() { } } -void TCPClientSocket::OnFileCanReadWithoutBlocking(int fd) { +void TCPClientSocketLibevent::OnFileCanReadWithoutBlocking(int fd) { // When a socket connects it signals both Read and Write, we handle // DidCompleteConnect() in the write handler. - if (!waiting_connect_ && callback_) { + if (!waiting_connect_ && read_callback_) { DidCompleteRead(); } } -void TCPClientSocket::OnFileCanWriteWithoutBlocking(int fd) { +void TCPClientSocketLibevent::OnFileCanWriteWithoutBlocking(int fd) { if (waiting_connect_) { DidCompleteConnect(); } else if (write_callback_) { @@ -365,7 +365,8 @@ void TCPClientSocket::OnFileCanWriteWithoutBlocking(int fd) { } } -int TCPClientSocket::GetPeerName(struct sockaddr *name, socklen_t *namelen) { +int TCPClientSocketLibevent::GetPeerName(struct sockaddr *name, + socklen_t *namelen) { return ::getpeername(socket_, name, namelen); } diff --git a/net/base/tcp_client_socket_libevent.h b/net/base/tcp_client_socket_libevent.h new file mode 100644 index 0000000..c933bfa --- /dev/null +++ b/net/base/tcp_client_socket_libevent.h @@ -0,0 +1,92 @@ +// Copyright (c) 2006-2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_TCP_CLIENT_SOCKET_LIBEVENT_H_ +#define NET_BASE_TCP_CLIENT_SOCKET_LIBEVENT_H_ + +#include <sys/socket.h> // for struct sockaddr + +#include "base/message_loop.h" +#include "net/base/address_list.h" +#include "net/base/client_socket.h" +#include "net/base/completion_callback.h" + +struct event; // From libevent + +namespace net { + +// A client socket that uses TCP as the transport layer. +class TCPClientSocketLibevent : public ClientSocket, + public MessageLoopForIO::Watcher { + public: + // The IP address(es) and port number to connect to. The TCP socket will try + // each IP address in the list until it succeeds in establishing a + // connection. + explicit TCPClientSocketLibevent(const AddressList& addresses); + + ~TCPClientSocketLibevent(); + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + + // Socket methods: + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported + virtual int Read(char* buf, int buf_len, CompletionCallback* callback); + virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); + + // Identical to posix system call of same name + // Needed by ssl_client_socket_nss + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); + + private: + // Called by MessagePumpLibevent when the socket is ready to do I/O + void OnFileCanReadWithoutBlocking(int fd); + void OnFileCanWriteWithoutBlocking(int fd); + + void DoReadCallback(int rv); + void DoWriteCallback(int rv); + void DidCompleteRead(); + void DidCompleteWrite(); + void DidCompleteConnect(); + + int CreateSocket(const struct addrinfo* ai); + + int socket_; + + // The list of addresses we should try in order to establish a connection. + AddressList addresses_; + + // Where we are in above list, or NULL if all addrinfos have been tried. + const struct addrinfo* current_ai_; + + // Whether we're currently waiting for connect() to complete + bool waiting_connect_; + + // The socket's libevent wrapper + MessageLoopForIO::FileDescriptorWatcher socket_watcher_; + + // The buffer used by OnSocketReady to retry Read requests + char* read_buf_; + int read_buf_len_; + + // The buffer used by OnSocketReady to retry Write requests + const char* write_buf_; + int write_buf_len_; + + // External callback; called when read is complete. + CompletionCallback* read_callback_; + + // External callback; called when write is complete. + CompletionCallback* write_callback_; + + DISALLOW_COPY_AND_ASSIGN(TCPClientSocketLibevent); +}; + +} // namespace net + +#endif // NET_BASE_TCP_CLIENT_SOCKET_LIBEVENT_H_ diff --git a/net/base/tcp_client_socket_unittest.cc b/net/base/tcp_client_socket_unittest.cc index 536528b..1c40ac6 100644 --- a/net/base/tcp_client_socket_unittest.cc +++ b/net/base/tcp_client_socket_unittest.cc @@ -35,6 +35,7 @@ class TCPClientSocketTest protected: int listen_port_; + scoped_ptr<net::TCPClientSocket> sock_; private: scoped_refptr<ListenSocket> listen_sock_; @@ -48,34 +49,32 @@ void TCPClientSocketTest::SetUp() { ListenSocket *sock = NULL; int port; // Range of ports to listen on. Shouldn't need to try many. - static const int kMinPort = 10100; - static const int kMaxPort = 10200; + const int kMinPort = 10100; + const int kMaxPort = 10200; #if defined(OS_WIN) net::EnsureWinsockInit(); #endif for (port = kMinPort; port < kMaxPort; port++) { sock = ListenSocket::Listen("127.0.0.1", port, this); if (sock) - break; + break; } ASSERT_TRUE(sock != NULL); listen_sock_ = sock; listen_port_ = port; -} -TEST_F(TCPClientSocketTest, Connect) { net::AddressList addr; net::HostResolver resolver; - TestCompletionCallback callback; - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); - EXPECT_EQ(rv, net::OK); - - net::TCPClientSocket sock(addr); + CHECK(rv == net::OK); + sock_.reset(new net::TCPClientSocket(addr)); +} - EXPECT_FALSE(sock.IsConnected()); +TEST_F(TCPClientSocketTest, Connect) { + TestCompletionCallback callback; + EXPECT_FALSE(sock_->IsConnected()); - rv = sock.Connect(&callback); + int rv = sock_->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -83,10 +82,10 @@ TEST_F(TCPClientSocketTest, Connect) { EXPECT_EQ(rv, net::OK); } - EXPECT_TRUE(sock.IsConnected()); + EXPECT_TRUE(sock_->IsConnected()); - sock.Disconnect(); - EXPECT_FALSE(sock.IsConnected()); + sock_->Disconnect(); + EXPECT_FALSE(sock_->IsConnected()); } // TODO(wtc): Add unit tests for IsConnectedAndIdle: @@ -94,19 +93,8 @@ TEST_F(TCPClientSocketTest, Connect) { // - Server sends data unexpectedly. TEST_F(TCPClientSocketTest, Read) { - net::AddressList addr; - net::HostResolver resolver; TestCompletionCallback callback; - - int rv = resolver.Resolve("localhost", listen_port_, &addr, &callback); - EXPECT_EQ(rv, net::ERR_IO_PENDING); - - rv = callback.WaitForResult(); - EXPECT_EQ(rv, net::OK); - - net::TCPClientSocket sock(addr); - - rv = sock.Connect(&callback); + int rv = sock_->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -115,7 +103,7 @@ TEST_F(TCPClientSocketTest, Read) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock_->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -125,7 +113,7 @@ TEST_F(TCPClientSocketTest, Read) { char buf[4096]; for (;;) { - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock_->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -138,16 +126,8 @@ TEST_F(TCPClientSocketTest, Read) { } TEST_F(TCPClientSocketTest, Read_SmallChunks) { - net::AddressList addr; - net::HostResolver resolver; TestCompletionCallback callback; - - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); - EXPECT_EQ(rv, net::OK); - - net::TCPClientSocket sock(addr); - - rv = sock.Connect(&callback); + int rv = sock_->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -156,7 +136,7 @@ TEST_F(TCPClientSocketTest, Read_SmallChunks) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock_->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -166,7 +146,7 @@ TEST_F(TCPClientSocketTest, Read_SmallChunks) { char buf[1]; for (;;) { - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock_->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -179,16 +159,8 @@ TEST_F(TCPClientSocketTest, Read_SmallChunks) { } TEST_F(TCPClientSocketTest, Read_Interrupted) { - net::AddressList addr; - net::HostResolver resolver; TestCompletionCallback callback; - - int rv = resolver.Resolve("localhost", listen_port_, &addr, NULL); - EXPECT_EQ(rv, net::OK); - - net::TCPClientSocket sock(addr); - - rv = sock.Connect(&callback); + int rv = sock_->Connect(&callback); if (rv != net::OK) { ASSERT_EQ(rv, net::ERR_IO_PENDING); @@ -197,7 +169,7 @@ TEST_F(TCPClientSocketTest, Read_Interrupted) { } const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - rv = sock.Write(request_text, arraysize(request_text) - 1, &callback); + rv = sock_->Write(request_text, arraysize(request_text) - 1, &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) { @@ -207,7 +179,7 @@ TEST_F(TCPClientSocketTest, Read_Interrupted) { // Do a partial read and then exit. This test should not crash! char buf[512]; - rv = sock.Read(buf, sizeof(buf), &callback); + rv = sock_->Read(buf, sizeof(buf), &callback); EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); if (rv == net::ERR_IO_PENDING) @@ -215,3 +187,81 @@ TEST_F(TCPClientSocketTest, Read_Interrupted) { EXPECT_NE(rv, 0); } + +TEST_F(TCPClientSocketTest, FullDuplex_ReadFirst) { + TestCompletionCallback callback; + int rv = sock_->Connect(&callback); + if (rv != net::OK) { + ASSERT_EQ(rv, net::ERR_IO_PENDING); + + rv = callback.WaitForResult(); + EXPECT_EQ(rv, net::OK); + } + + char buf[4096]; + rv = sock_->Read(buf, sizeof(buf), &callback); + EXPECT_EQ(net::ERR_IO_PENDING, rv); + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + TestCompletionCallback write_callback; + rv = sock_->Write(request_text, arraysize(request_text) - 1, &write_callback); + EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + + if (rv == net::ERR_IO_PENDING) { + rv = write_callback.WaitForResult(); + EXPECT_EQ(rv, static_cast<int>(arraysize(request_text) - 1)); + } + + rv = callback.WaitForResult(); + EXPECT_GE(rv, 0); + while (rv > 0) { + rv = sock_->Read(buf, sizeof(buf), &callback); + EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_GE(rv, 0); + if (rv <= 0) + break; + } +} + +TEST_F(TCPClientSocketTest, FullDuplex_WriteFirst) { + TestCompletionCallback callback; + int rv = sock_->Connect(&callback); + if (rv != net::OK) { + ASSERT_EQ(rv, net::ERR_IO_PENDING); + + rv = callback.WaitForResult(); + EXPECT_EQ(rv, net::OK); + } + + const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; + TestCompletionCallback write_callback; + rv = sock_->Write(request_text, arraysize(request_text) - 1, &write_callback); + EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + + char buf[4096]; + int read_rv = sock_->Read(buf, sizeof(buf), &callback); + EXPECT_TRUE(read_rv >= 0 || read_rv == net::ERR_IO_PENDING); + + if (rv == net::ERR_IO_PENDING) { + rv = write_callback.WaitForResult(); + EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); + } + + rv = callback.WaitForResult(); + EXPECT_GE(rv, 0); + while (rv > 0) { + rv = sock_->Read(buf, sizeof(buf), &callback); + EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + + if (rv == net::ERR_IO_PENDING) + rv = callback.WaitForResult(); + + EXPECT_GE(rv, 0); + if (rv <= 0) + break; + } +} diff --git a/net/base/tcp_client_socket_win.cc b/net/base/tcp_client_socket_win.cc index 9d69505..50e99a3 100644 --- a/net/base/tcp_client_socket_win.cc +++ b/net/base/tcp_client_socket_win.cc @@ -2,8 +2,10 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/base/tcp_client_socket.h" +#include "net/base/tcp_client_socket_win.h" +#include "base/basictypes.h" +#include "base/compiler_specific.h" #include "base/memory_debug.h" #include "base/string_util.h" #include "base/sys_info.h" @@ -13,9 +15,25 @@ namespace net { +namespace { + +// Waits for the (manual-reset) event object to become signaled and resets +// it. Called after a Winsock function succeeds synchronously +// +// Our testing shows that except in rare cases (when running inside QEMU), +// the event object is already signaled at this point, so we just call this +// method on the IO thread to avoid a context switch. +void WaitForAndResetEvent(WSAEVENT hEvent) { + // TODO(wtc): Remove the CHECKs after enough testing. + DWORD wait_rv = WaitForSingleObject(hEvent, INFINITE); + CHECK(wait_rv == WAIT_OBJECT_0); + BOOL ok = WSAResetEvent(hEvent); + CHECK(ok); +} + //----------------------------------------------------------------------------- -static int MapWinsockError(DWORD err) { +int MapWinsockError(DWORD err) { // There are numerous Winsock error codes, but these are the ones we thus far // find interesting. switch (err) { @@ -52,23 +70,31 @@ static int MapWinsockError(DWORD err) { } } +} // namespace + //----------------------------------------------------------------------------- -TCPClientSocket::TCPClientSocket(const AddressList& addresses) +TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses) : socket_(INVALID_SOCKET), addresses_(addresses), current_ai_(addresses_.head()), - wait_state_(NOT_WAITING), - callback_(NULL) { - memset(&overlapped_, 0, sizeof(overlapped_)); + waiting_connect_(false), + waiting_read_(false), + waiting_write_(false), + ALLOW_THIS_IN_INITIALIZER_LIST(reader_(this)), + ALLOW_THIS_IN_INITIALIZER_LIST(writer_(this)), + read_callback_(NULL), + write_callback_(NULL) { + memset(&read_overlapped_, 0, sizeof(read_overlapped_)); + memset(&write_overlapped_, 0, sizeof(write_overlapped_)); EnsureWinsockInit(); } -TCPClientSocket::~TCPClientSocket() { +TCPClientSocketWin::~TCPClientSocketWin() { Disconnect(); } -int TCPClientSocket::Connect(CompletionCallback* callback) { +int TCPClientSocketWin::Connect(CompletionCallback* callback) { // If already connected, then just return OK. if (socket_ != INVALID_SOCKET) return OK; @@ -82,14 +108,16 @@ int TCPClientSocket::Connect(CompletionCallback* callback) { return rv; // WSACreateEvent creates a manual-reset event object. - overlapped_.hEvent = WSACreateEvent(); + read_overlapped_.hEvent = WSACreateEvent(); // WSAEventSelect sets the socket to non-blocking mode as a side effect. // Our connect() and recv() calls require that the socket be non-blocking. - WSAEventSelect(socket_, overlapped_.hEvent, FD_CONNECT); + WSAEventSelect(socket_, read_overlapped_.hEvent, FD_CONNECT); + + write_overlapped_.hEvent = WSACreateEvent(); if (!connect(socket_, ai->ai_addr, static_cast<int>(ai->ai_addrlen))) { // Connected without waiting! - WaitForAndResetEvent(); + WaitForAndResetEvent(read_overlapped_.hEvent); TRACE_EVENT_END("socket.connect", this, ""); return OK; } @@ -100,26 +128,29 @@ int TCPClientSocket::Connect(CompletionCallback* callback) { return MapWinsockError(err); } - watcher_.StartWatching(overlapped_.hEvent, this); - wait_state_ = WAITING_CONNECT; - callback_ = callback; + read_watcher_.StartWatching(read_overlapped_.hEvent, &reader_); + waiting_connect_ = true; + read_callback_ = callback; return ERR_IO_PENDING; } -void TCPClientSocket::Disconnect() { +void TCPClientSocketWin::Disconnect() { if (socket_ == INVALID_SOCKET) return; TRACE_EVENT_INSTANT("socket.disconnect", this, ""); // Make sure the message loop is not watching this object anymore. - watcher_.StopWatching(); + read_watcher_.StopWatching(); + write_watcher_.StopWatching(); // Cancel any pending IO and wait for it to be aborted. - if (wait_state_ == WAITING_READ || wait_state_ == WAITING_WRITE) { + if (waiting_read_ || waiting_write_) { CancelIo(reinterpret_cast<HANDLE>(socket_)); - WaitForSingleObject(overlapped_.hEvent, INFINITE); - wait_state_ = NOT_WAITING; + if (waiting_read_) + WaitForSingleObject(read_overlapped_.hEvent, INFINITE); + if (waiting_write_) + WaitForSingleObject(write_overlapped_.hEvent, INFINITE); } // In most socket implementations, closing a socket results in a graceful @@ -131,15 +162,21 @@ void TCPClientSocket::Disconnect() { closesocket(socket_); socket_ = INVALID_SOCKET; - WSACloseEvent(overlapped_.hEvent); - memset(&overlapped_, 0, sizeof(overlapped_)); + WSACloseEvent(read_overlapped_.hEvent); + memset(&read_overlapped_, 0, sizeof(read_overlapped_)); + WSACloseEvent(write_overlapped_.hEvent); + memset(&write_overlapped_, 0, sizeof(write_overlapped_)); // Reset for next time. current_ai_ = addresses_.head(); + + waiting_read_ = false; + waiting_write_ = false; + waiting_connect_ = false; } -bool TCPClientSocket::IsConnected() const { - if (socket_ == INVALID_SOCKET || wait_state_ == WAITING_CONNECT) +bool TCPClientSocketWin::IsConnected() const { + if (socket_ == INVALID_SOCKET || waiting_connect_) return false; // Check if connection is alive. @@ -153,8 +190,8 @@ bool TCPClientSocket::IsConnected() const { return true; } -bool TCPClientSocket::IsConnectedAndIdle() const { - if (socket_ == INVALID_SOCKET || wait_state_ == WAITING_CONNECT) +bool TCPClientSocketWin::IsConnectedAndIdle() const { + if (socket_ == INVALID_SOCKET || waiting_connect_) return false; // Check if connection is alive and we haven't received any data @@ -169,23 +206,24 @@ bool TCPClientSocket::IsConnectedAndIdle() const { return true; } -int TCPClientSocket::Read(char* buf, - int buf_len, - CompletionCallback* callback) { - DCHECK(socket_ != INVALID_SOCKET); - DCHECK(wait_state_ == NOT_WAITING); - DCHECK(!callback_); +int TCPClientSocketWin::Read(char* buf, + int buf_len, + CompletionCallback* callback) { + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_read_); + DCHECK(!read_callback_); - buffer_.len = buf_len; - buffer_.buf = buf; + read_buffer_.len = buf_len; + read_buffer_.buf = buf; TRACE_EVENT_BEGIN("socket.read", this, ""); // TODO(wtc): Remove the CHECK after enough testing. - CHECK(WaitForSingleObject(overlapped_.hEvent, 0) == WAIT_TIMEOUT); + CHECK(WaitForSingleObject(read_overlapped_.hEvent, 0) == WAIT_TIMEOUT); DWORD num, flags = 0; - int rv = WSARecv(socket_, &buffer_, 1, &num, &flags, &overlapped_, NULL); + int rv = WSARecv( + socket_, &read_buffer_, 1, &num, &flags, &read_overlapped_, NULL); if (rv == 0) { - WaitForAndResetEvent(); + WaitForAndResetEvent(read_overlapped_.hEvent); TRACE_EVENT_END("socket.read", this, StringPrintf("%d bytes", num)); // Because of how WSARecv fills memory when used asynchronously, Purify @@ -194,51 +232,52 @@ int TCPClientSocket::Read(char* buf, // individual bytes. We override that in PURIFY builds to avoid the false // error reports. // See bug 5297. - base::MemoryDebug::MarkAsInitialized(buffer_.buf, num); + base::MemoryDebug::MarkAsInitialized(read_buffer_.buf, num); return static_cast<int>(num); } int err = WSAGetLastError(); if (err == WSA_IO_PENDING) { - watcher_.StartWatching(overlapped_.hEvent, this); - wait_state_ = WAITING_READ; - callback_ = callback; + read_watcher_.StartWatching(read_overlapped_.hEvent, &reader_); + waiting_read_ = true; + read_callback_ = callback; return ERR_IO_PENDING; } return MapWinsockError(err); } -int TCPClientSocket::Write(const char* buf, - int buf_len, - CompletionCallback* callback) { - DCHECK(socket_ != INVALID_SOCKET); - DCHECK(wait_state_ == NOT_WAITING); - DCHECK(!callback_); - DCHECK(buf_len > 0); +int TCPClientSocketWin::Write(const char* buf, + int buf_len, + CompletionCallback* callback) { + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_write_); + DCHECK(!write_callback_); + DCHECK_GT(buf_len, 0); - buffer_.len = buf_len; - buffer_.buf = const_cast<char*>(buf); + write_buffer_.len = buf_len; + write_buffer_.buf = const_cast<char*>(buf); TRACE_EVENT_BEGIN("socket.write", this, ""); // TODO(wtc): Remove the CHECK after enough testing. - CHECK(WaitForSingleObject(overlapped_.hEvent, 0) == WAIT_TIMEOUT); + CHECK(WaitForSingleObject(write_overlapped_.hEvent, 0) == WAIT_TIMEOUT); DWORD num; - int rv = WSASend(socket_, &buffer_, 1, &num, 0, &overlapped_, NULL); + int rv = + WSASend(socket_, &write_buffer_, 1, &num, 0, &write_overlapped_, NULL); if (rv == 0) { - WaitForAndResetEvent(); + WaitForAndResetEvent(write_overlapped_.hEvent); TRACE_EVENT_END("socket.write", this, StringPrintf("%d bytes", num)); return static_cast<int>(num); } int err = WSAGetLastError(); if (err == WSA_IO_PENDING) { - watcher_.StartWatching(overlapped_.hEvent, this); - wait_state_ = WAITING_WRITE; - callback_ = callback; + read_watcher_.StartWatching(read_overlapped_.hEvent, &writer_); + waiting_write_ = true; + read_callback_ = callback; return ERR_IO_PENDING; } return MapWinsockError(err); } -int TCPClientSocket::CreateSocket(const struct addrinfo* ai) { +int TCPClientSocketWin::CreateSocket(const struct addrinfo* ai) { socket_ = WSASocket(ai->ai_family, ai->ai_socktype, ai->ai_protocol, NULL, 0, WSA_FLAG_OVERLAPPED); if (socket_ == INVALID_SOCKET) { @@ -302,29 +341,38 @@ int TCPClientSocket::CreateSocket(const struct addrinfo* ai) { return OK; } -void TCPClientSocket::DoCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); - DCHECK(callback_); +void TCPClientSocketWin::DoReadCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(read_callback_); + + // since Run may result in Read being called, clear read_callback_ up front. + CompletionCallback* c = read_callback_; + read_callback_ = NULL; + c->Run(rv); +} + +void TCPClientSocketWin::DoWriteCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + DCHECK(write_callback_); - // since Run may result in Read being called, clear callback_ up front. - CompletionCallback* c = callback_; - callback_ = NULL; + // since Run may result in Read being called, clear read_callback_ up front. + CompletionCallback* c = write_callback_; + write_callback_ = NULL; c->Run(rv); } -void TCPClientSocket::DidCompleteConnect() { +void TCPClientSocketWin::DidCompleteConnect() { int result; TRACE_EVENT_END("socket.connect", this, ""); - wait_state_ = NOT_WAITING; + waiting_connect_ = false; WSANETWORKEVENTS events; - int rv = WSAEnumNetworkEvents(socket_, overlapped_.hEvent, &events); + int rv = WSAEnumNetworkEvents(socket_, read_overlapped_.hEvent, &events); if (rv == SOCKET_ERROR) { NOTREACHED(); result = MapWinsockError(WSAGetLastError()); } else if (events.lNetworkEvents & FD_CONNECT) { - wait_state_ = NOT_WAITING; DWORD error_code = static_cast<DWORD>(events.iErrorCode[FD_CONNECT_BIT]); if (current_ai_->ai_next && ( error_code == WSAEADDRNOTAVAIL || @@ -337,7 +385,7 @@ void TCPClientSocket::DidCompleteConnect() { const struct addrinfo* next = current_ai_->ai_next; Disconnect(); current_ai_ = next; - result = Connect(callback_); + result = Connect(read_callback_); } else { result = MapWinsockError(error_code); } @@ -347,46 +395,41 @@ void TCPClientSocket::DidCompleteConnect() { } if (result != ERR_IO_PENDING) - DoCallback(result); + DoReadCallback(result); } -void TCPClientSocket::DidCompleteIO() { - DWORD num_bytes, flags; - BOOL ok = WSAGetOverlappedResult( - socket_, &overlapped_, &num_bytes, FALSE, &flags); - WSAResetEvent(overlapped_.hEvent); - if (wait_state_ == WAITING_READ) { - TRACE_EVENT_END("socket.read", this, StringPrintf("%d bytes", num_bytes)); +void TCPClientSocketWin::ReadDelegate::OnObjectSignaled(HANDLE object) { + DCHECK_EQ(object, tcp_socket_->read_overlapped_.hEvent); + + if (tcp_socket_->waiting_connect_) { + tcp_socket_->DidCompleteConnect(); } else { - TRACE_EVENT_END("socket.write", this, StringPrintf("%d bytes", num_bytes)); + DWORD num_bytes, flags; + BOOL ok = WSAGetOverlappedResult( + tcp_socket_->socket_, &tcp_socket_->read_overlapped_, &num_bytes, + FALSE, &flags); + WSAResetEvent(object); + TRACE_EVENT_END("socket.read", tcp_socket_, + StringPrintf("%d bytes", num_bytes)); + tcp_socket_->waiting_read_ = false; + tcp_socket_->DoReadCallback( + ok ? num_bytes : MapWinsockError(WSAGetLastError())); } - wait_state_ = NOT_WAITING; - DoCallback(ok ? num_bytes : MapWinsockError(WSAGetLastError())); } -void TCPClientSocket::OnObjectSignaled(HANDLE object) { - DCHECK(object == overlapped_.hEvent); - - switch (wait_state_) { - case WAITING_CONNECT: - DidCompleteConnect(); - break; - case WAITING_READ: - case WAITING_WRITE: - DidCompleteIO(); - break; - default: - NOTREACHED(); - break; - } -} +void TCPClientSocketWin::WriteDelegate::OnObjectSignaled(HANDLE object) { + DCHECK_EQ(object, tcp_socket_->write_overlapped_.hEvent); -void TCPClientSocket::WaitForAndResetEvent() { - // TODO(wtc): Remove the CHECKs after enough testing. - DWORD wait_rv = WaitForSingleObject(overlapped_.hEvent, INFINITE); - CHECK(wait_rv == WAIT_OBJECT_0); - BOOL ok = WSAResetEvent(overlapped_.hEvent); - CHECK(ok); + DWORD num_bytes, flags; + BOOL ok = WSAGetOverlappedResult( + tcp_socket_->socket_, &tcp_socket_->write_overlapped_, &num_bytes, + FALSE, &flags); + WSAResetEvent(object); + TRACE_EVENT_END("socket.write", tcp_socket_, + StringPrintf("%d bytes", num_bytes)); + tcp_socket_->waiting_write_ = false; + tcp_socket_->DoWriteCallback( + ok ? num_bytes : MapWinsockError(WSAGetLastError())); } } // namespace net diff --git a/net/base/tcp_client_socket_win.h b/net/base/tcp_client_socket_win.h new file mode 100644 index 0000000..63f2b93 --- /dev/null +++ b/net/base/tcp_client_socket_win.h @@ -0,0 +1,114 @@ +// Copyright (c) 2006-2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_TCP_CLIENT_SOCKET_WIN_H_ +#define NET_BASE_TCP_CLIENT_SOCKET_WIN_H_ + +#include <ws2tcpip.h> + +#include "base/object_watcher.h" +#include "net/base/address_list.h" +#include "net/base/client_socket.h" +#include "net/base/completion_callback.h" + +namespace net { + +class TCPClientSocketWin : public ClientSocket { + public: + // The IP address(es) and port number to connect to. The TCP socket will try + // each IP address in the list until it succeeds in establishing a + // connection. + explicit TCPClientSocketWin(const AddressList& addresses); + + ~TCPClientSocketWin(); + + // ClientSocket methods: + virtual int Connect(CompletionCallback* callback); + virtual void Disconnect(); + virtual bool IsConnected() const; + virtual bool IsConnectedAndIdle() const; + + // Socket methods: + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported + virtual int Read(char* buf, int buf_len, CompletionCallback* callback); + virtual int Write(const char* buf, int buf_len, CompletionCallback* callback); + + private: + class ReadDelegate : public base::ObjectWatcher::Delegate { + public: + explicit ReadDelegate(TCPClientSocketWin* tcp_socket) + : tcp_socket_(tcp_socket) { } + virtual ~ReadDelegate() { } + + // base::ObjectWatcher::Delegate methods: + virtual void OnObjectSignaled(HANDLE object); + + private: + TCPClientSocketWin* const tcp_socket_; + }; + + class WriteDelegate : public base::ObjectWatcher::Delegate { + public: + explicit WriteDelegate(TCPClientSocketWin* tcp_socket) + : tcp_socket_(tcp_socket) { } + virtual ~WriteDelegate() { } + + // base::ObjectWatcher::Delegate methods: + virtual void OnObjectSignaled(HANDLE object); + + private: + TCPClientSocketWin* const tcp_socket_; + }; + + int CreateSocket(const struct addrinfo* ai); + void DoReadCallback(int rv); + void DoWriteCallback(int rv); + void DidCompleteConnect(); + + SOCKET socket_; + + // The list of addresses we should try in order to establish a connection. + AddressList addresses_; + + // Where we are in above list, or NULL if all addrinfos have been tried. + const struct addrinfo* current_ai_; + + // The various states that the socket could be in. + bool waiting_connect_; + bool waiting_read_; + bool waiting_write_; + + // The separate OVERLAPPED variables for asynchronous operation. + // |read_overlapped_| is used for both Connect() and Read(). + // |write_overlapped_| is only used for Write(); + OVERLAPPED read_overlapped_; + OVERLAPPED write_overlapped_; + + // The buffers used in Read() and Write(). + WSABUF read_buffer_; + WSABUF write_buffer_; + + // |reader_| handles the signals from |read_watcher_|. + ReadDelegate reader_; + // |writer_| handles the signals from |write_watcher_|. + WriteDelegate writer_; + + // |read_watcher_| watches for events from Connect() and Read(). + base::ObjectWatcher read_watcher_; + // |write_watcher_| watches for events from Write(); + base::ObjectWatcher write_watcher_; + + // External callback; called when connect or read is complete. + CompletionCallback* read_callback_; + + // External callback; called when write is complete. + CompletionCallback* write_callback_; + + DISALLOW_COPY_AND_ASSIGN(TCPClientSocketWin); +}; + +} // namespace net + +#endif // NET_BASE_TCP_CLIENT_SOCKET_WIN_H_ diff --git a/net/net.gyp b/net/net.gyp index f1a4805..fa6c804 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -121,7 +121,9 @@ 'base/ssl_test_util.cc', 'base/tcp_client_socket.h', 'base/tcp_client_socket_libevent.cc', + 'base/tcp_client_socket_libevent.h', 'base/tcp_client_socket_win.cc', + 'base/tcp_client_socket_win.h', 'base/telnet_server.cc', 'base/telnet_server.h', 'base/upload_data.cc', |