summaryrefslogtreecommitdiffstats
path: root/net/socket/socket_test_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'net/socket/socket_test_util.cc')
-rw-r--r--net/socket/socket_test_util.cc379
1 files changed, 379 insertions, 0 deletions
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc
new file mode 100644
index 0000000..be98865
--- /dev/null
+++ b/net/socket/socket_test_util.cc
@@ -0,0 +1,379 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socket_test_util.h"
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/message_loop.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ssl_info.h"
+#include "net/socket/socket.h"
+#include "net/socket/ssl_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace {
+
+class MockClientSocket : public net::SSLClientSocket {
+ public:
+ MockClientSocket();
+
+ // ClientSocket methods:
+ virtual int Connect(net::CompletionCallback* callback) = 0;
+
+ // SSLClientSocket methods:
+ virtual void GetSSLInfo(net::SSLInfo* ssl_info);
+ virtual void GetSSLCertRequestInfo(
+ net::SSLCertRequestInfo* cert_request_info);
+ virtual void Disconnect();
+ virtual bool IsConnected() const;
+ virtual bool IsConnectedAndIdle() const;
+
+ // Socket methods:
+ virtual int Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) = 0;
+ virtual int Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) = 0;
+
+#if defined(OS_LINUX)
+ virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen);
+#endif
+
+ protected:
+ void RunCallbackAsync(net::CompletionCallback* callback, int result);
+ void RunCallback(int result);
+
+ ScopedRunnableMethodFactory<MockClientSocket> method_factory_;
+ net::CompletionCallback* callback_;
+ bool connected_;
+};
+
+class MockTCPClientSocket : public MockClientSocket {
+ public:
+ MockTCPClientSocket(const net::AddressList& addresses,
+ net::MockSocket* socket);
+
+ // ClientSocket methods:
+ virtual int Connect(net::CompletionCallback* callback);
+
+ // Socket methods:
+ virtual int Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback);
+ virtual int Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback);
+
+ private:
+ net::MockSocket* data_;
+ int read_offset_;
+ net::MockRead* read_data_;
+ bool need_read_data_;
+};
+
+class MockSSLClientSocket : public MockClientSocket {
+ public:
+ MockSSLClientSocket(
+ net::ClientSocket* transport_socket,
+ const std::string& hostname,
+ const net::SSLConfig& ssl_config,
+ net::MockSSLSocket* socket);
+ ~MockSSLClientSocket();
+
+ virtual void GetSSLInfo(net::SSLInfo* ssl_info);
+
+ virtual int Connect(net::CompletionCallback* callback);
+ virtual void Disconnect();
+
+ // Socket methods:
+ virtual int Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback);
+ virtual int Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback);
+
+ private:
+ class ConnectCallback;
+
+ scoped_ptr<ClientSocket> transport_;
+ net::MockSSLSocket* data_;
+};
+
+MockClientSocket::MockClientSocket()
+ : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
+ callback_(NULL),
+ connected_(false) {
+}
+
+void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
+ NOTREACHED();
+}
+
+void MockClientSocket::GetSSLCertRequestInfo(
+ net::SSLCertRequestInfo* cert_request_info) {
+ NOTREACHED();
+}
+
+void MockClientSocket::Disconnect() {
+ connected_ = false;
+ callback_ = NULL;
+}
+
+bool MockClientSocket::IsConnected() const {
+ return connected_;
+}
+
+bool MockClientSocket::IsConnectedAndIdle() const {
+ return connected_;
+}
+
+#if defined(OS_LINUX)
+int MockClientSocket::GetPeerName(struct sockaddr *name, socklen_t *namelen) {
+ memset(reinterpret_cast<char *>(name), 0, *namelen);
+ return net::OK;
+}
+#endif // defined(OS_LINUX)
+
+void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
+ int result) {
+ callback_ = callback;
+ MessageLoop::current()->PostTask(FROM_HERE,
+ method_factory_.NewRunnableMethod(
+ &MockClientSocket::RunCallback, result));
+}
+
+void MockClientSocket::RunCallback(int result) {
+ net::CompletionCallback* c = callback_;
+ callback_ = NULL;
+ if (c)
+ c->Run(result);
+}
+
+MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
+ net::MockSocket* socket)
+ : data_(socket),
+ read_offset_(0),
+ read_data_(NULL),
+ need_read_data_(true) {
+ DCHECK(data_);
+ data_->Reset();
+}
+
+int MockTCPClientSocket::Connect(net::CompletionCallback* callback) {
+ DCHECK(!callback_);
+ if (connected_)
+ return net::OK;
+ connected_ = true;
+ if (data_->connect_data().async) {
+ RunCallbackAsync(callback, data_->connect_data().result);
+ return net::ERR_IO_PENDING;
+ }
+ return data_->connect_data().result;
+}
+
+int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) {
+ DCHECK(!callback_);
+ if (need_read_data_) {
+ read_data_ = data_->GetNextRead();
+ need_read_data_ = false;
+ }
+ int result = read_data_->result;
+ if (read_data_->data) {
+ if (read_data_->data_len - read_offset_ > 0) {
+ result = std::min(buf_len, read_data_->data_len - read_offset_);
+ memcpy(buf->data(), read_data_->data + read_offset_, result);
+ read_offset_ += result;
+ if (read_offset_ == read_data_->data_len) {
+ need_read_data_ = true;
+ read_offset_ = 0;
+ }
+ } else {
+ result = 0; // EOF
+ }
+ }
+ if (read_data_->async) {
+ RunCallbackAsync(callback, result);
+ return net::ERR_IO_PENDING;
+ }
+ return result;
+}
+
+int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) {
+ DCHECK(buf);
+ DCHECK(buf_len > 0);
+ DCHECK(!callback_);
+
+ std::string data(buf->data(), buf_len);
+ net::MockWriteResult write_result = data_->OnWrite(data);
+
+ if (write_result.async) {
+ RunCallbackAsync(callback, write_result.result);
+ return net::ERR_IO_PENDING;
+ }
+ return write_result.result;
+}
+
+class MockSSLClientSocket::ConnectCallback :
+ public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
+ public:
+ ConnectCallback(MockSSLClientSocket *ssl_client_socket,
+ net::CompletionCallback* user_callback,
+ int rv)
+ : ALLOW_THIS_IN_INITIALIZER_LIST(
+ net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
+ this, &ConnectCallback::Wrapper)),
+ ssl_client_socket_(ssl_client_socket),
+ user_callback_(user_callback),
+ rv_(rv) {
+ }
+
+ private:
+ void Wrapper(int rv) {
+ if (rv_ == net::OK)
+ ssl_client_socket_->connected_ = true;
+ user_callback_->Run(rv_);
+ delete this;
+ }
+
+ MockSSLClientSocket* ssl_client_socket_;
+ net::CompletionCallback* user_callback_;
+ int rv_;
+};
+
+MockSSLClientSocket::MockSSLClientSocket(
+ net::ClientSocket* transport_socket,
+ const std::string& hostname,
+ const net::SSLConfig& ssl_config,
+ net::MockSSLSocket* socket)
+ : transport_(transport_socket),
+ data_(socket) {
+ DCHECK(data_);
+}
+
+MockSSLClientSocket::~MockSSLClientSocket() {
+ Disconnect();
+}
+
+void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
+ ssl_info->Reset();
+}
+
+int MockSSLClientSocket::Connect(net::CompletionCallback* callback) {
+ DCHECK(!callback_);
+ ConnectCallback* connect_callback = new ConnectCallback(
+ this, callback, data_->connect.result);
+ int rv = transport_->Connect(connect_callback);
+ if (rv == net::OK) {
+ delete connect_callback;
+ if (data_->connect.async) {
+ RunCallbackAsync(callback, data_->connect.result);
+ return net::ERR_IO_PENDING;
+ }
+ if (data_->connect.result == net::OK)
+ connected_ = true;
+ return data_->connect.result;
+ }
+ return rv;
+}
+
+void MockSSLClientSocket::Disconnect() {
+ MockClientSocket::Disconnect();
+ if (transport_ != NULL)
+ transport_->Disconnect();
+}
+
+int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) {
+ DCHECK(!callback_);
+ return transport_->Read(buf, buf_len, callback);
+}
+
+int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
+ net::CompletionCallback* callback) {
+ DCHECK(!callback_);
+ return transport_->Write(buf, buf_len, callback);
+}
+
+} // namespace
+
+namespace net {
+
+MockRead* StaticMockSocket::GetNextRead() {
+ return &reads_[read_index_++];
+}
+
+MockWriteResult StaticMockSocket::OnWrite(const std::string& data) {
+ if (!writes_) {
+ // Not using mock writes; succeed synchronously.
+ return MockWriteResult(false, data.length());
+ }
+
+ // Check that what we are writing matches the expectation.
+ // Then give the mocked return value.
+ net::MockWrite* w = &writes_[write_index_++];
+ int result = w->result;
+ if (w->data) {
+ std::string expected_data(w->data, w->data_len);
+ EXPECT_EQ(expected_data, data);
+ if (expected_data != data)
+ return MockWriteResult(false, net::ERR_UNEXPECTED);
+ if (result == net::OK)
+ result = w->data_len;
+ }
+ return MockWriteResult(w->async, result);
+}
+
+void StaticMockSocket::Reset() {
+ read_index_ = 0;
+ write_index_ = 0;
+}
+
+DynamicMockSocket::DynamicMockSocket()
+ : read_(false, ERR_UNEXPECTED),
+ has_read_(false) {
+}
+
+MockRead* DynamicMockSocket::GetNextRead() {
+ if (!has_read_)
+ return unexpected_read();
+ has_read_ = false;
+ return &read_;
+}
+
+void DynamicMockSocket::Reset() {
+ has_read_ = false;
+}
+
+void DynamicMockSocket::SimulateRead(const char* data) {
+ EXPECT_FALSE(has_read_) << "Unconsumed read: " << read_.data;
+ read_ = MockRead(data);
+ has_read_ = true;
+}
+
+void MockClientSocketFactory::AddMockSocket(MockSocket* socket) {
+ mock_sockets_.Add(socket);
+}
+
+void MockClientSocketFactory::AddMockSSLSocket(MockSSLSocket* socket) {
+ mock_ssl_sockets_.Add(socket);
+}
+
+void MockClientSocketFactory::ResetNextMockIndexes() {
+ mock_sockets_.ResetNextIndex();
+ mock_ssl_sockets_.ResetNextIndex();
+}
+
+ClientSocket* MockClientSocketFactory::CreateTCPClientSocket(
+ const AddressList& addresses) {
+ return new MockTCPClientSocket(addresses, mock_sockets_.GetNext());
+}
+
+SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
+ ClientSocket* transport_socket,
+ const std::string& hostname,
+ const SSLConfig& ssl_config) {
+ return new MockSSLClientSocket(transport_socket, hostname, ssl_config,
+ mock_ssl_sockets_.GetNext());
+}
+
+} // namespace net