diff options
Diffstat (limited to 'net/base')
| -rw-r--r-- | net/base/tcp_client_socket.h | 65 | ||||
| -rw-r--r-- | net/base/tcp_client_socket_unittest.cc | 16 |
2 files changed, 81 insertions, 0 deletions
diff --git a/net/base/tcp_client_socket.h b/net/base/tcp_client_socket.h index 5fca519..07d8351 100644 --- a/net/base/tcp_client_socket.h +++ b/net/base/tcp_client_socket.h @@ -18,9 +18,13 @@ struct event; // From libevent #endif #include "base/scoped_ptr.h" +#include "base/task.h" +#include "base/thread.h" +#include "base/waitable_event.h" #include "net/base/address_list.h" #include "net/base/client_socket.h" #include "net/base/completion_callback.h" +#include "net/base/net_errors.h" namespace net { @@ -127,6 +131,67 @@ class TCPClientSocket : public ClientSocket, void DidCompleteConnect(); }; +// Tiny helper class to do a synchronous connect, +// in lieu of directly supporting that in TcpClientSocket. +// This avoids cluttering the main codepath with code only used by unit tests. +// TODO(dkegel): move this to its own header file. +class TCPClientSocketSyncConnector + : public base::RefCounted<TCPClientSocketSyncConnector> { + public: + // Connect given socket synchronously. + // Returns network error code. + static int Connect(net::TCPClientSocket* sock) { + // Start up a throwaway IO thread just for this. + // TODO(port): use some existing thread pool instead? + base::Thread io_thread("SyncConnect"); + base::Thread::Options options; + options.message_loop_type = MessageLoop::TYPE_IO; + io_thread.StartWithOptions(options); + + // Post a request to do the connect on that thread. + scoped_refptr<TCPClientSocketSyncConnector> connector = + new TCPClientSocketSyncConnector(sock); + io_thread.message_loop()->PostTask(FROM_HERE, NewRunnableMethod(connector.get(), + &net::TCPClientSocketSyncConnector::DoConnect)); + connector->Wait(); + return connector->GetError(); + } + + private: + // Start a connect. Must be called on an IO thread. + void DoConnect() { + net_error_ = sock_->Connect(&connect_callback_); + if (net_error_ != ERR_IO_PENDING) + event_.Signal(); + } + + // Callback called on same IO thread when connection complete. + void ConnectDone(int rv) { + net_error_ = rv; + event_.Signal(); + } + + // Call this after posting a call to DoConnect(). + void Wait() { event_.Wait(); } + + // Call this after Wait() if you need the final error code from the connect. + int GetError() { return net_error_; } + + // sock is owned by caller, but must remain valid while this object lives. + explicit TCPClientSocketSyncConnector(TCPClientSocket* sock) : + event_(false, false), + sock_(sock), + net_error_(0), + connect_callback_(this, &net::TCPClientSocketSyncConnector::ConnectDone) { + } + + base::WaitableEvent event_; + net::TCPClientSocket* sock_; + int net_error_; + net::CompletionCallbackImpl<TCPClientSocketSyncConnector> connect_callback_; + DISALLOW_COPY_AND_ASSIGN(TCPClientSocketSyncConnector); +}; + } // namespace net #endif // NET_BASE_TCP_CLIENT_SOCKET_H_ diff --git a/net/base/tcp_client_socket_unittest.cc b/net/base/tcp_client_socket_unittest.cc index 18e6c21..41c477a 100644 --- a/net/base/tcp_client_socket_unittest.cc +++ b/net/base/tcp_client_socket_unittest.cc @@ -49,6 +49,22 @@ TEST_F(TCPClientSocketTest, Connect) { EXPECT_FALSE(sock.IsConnected()); } +TEST_F(TCPClientSocketTest, SyncConnect) { + net::AddressList addr; + net::HostResolver resolver; + + int rv = resolver.Resolve("www.google.com", 80, &addr, NULL); + EXPECT_EQ(rv, net::OK); + + net::TCPClientSocket sock(addr); + + EXPECT_FALSE(sock.IsConnected()); + + rv = net::TCPClientSocketSyncConnector::Connect(&sock); + EXPECT_EQ(rv, net::OK); + EXPECT_TRUE(sock.IsConnected()); +} + TEST_F(TCPClientSocketTest, Read) { net::AddressList addr; net::HostResolver resolver; |
