summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorpliard@chromium.org <pliard@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2012-06-01 09:41:00 +0000
committerpliard@chromium.org <pliard@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2012-06-01 09:41:00 +0000
commita7885b94f220458ead7abace4d4fe004faca3989 (patch)
treeedb38efe969d5cb5c947980efb6d9430acf47d9b
parent84d88d337868e1a9dbee08b1c5acc81447d47e36 (diff)
downloadchromium_src-a7885b94f220458ead7abace4d4fe004faca3989.zip
chromium_src-a7885b94f220458ead7abace4d4fe004faca3989.tar.gz
chromium_src-a7885b94f220458ead7abace4d4fe004faca3989.tar.bz2
Add net/base/unix_domain_socket_posix*.
This is part of Chrome for Android upstreaming. It will be used by DevTools on Android. Review URL: https://chromiumcodereview.appspot.com/10391053 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@140010 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r--net/base/stream_listen_socket.cc12
-rw-r--r--net/base/stream_listen_socket.h18
-rw-r--r--net/base/tcp_listen_socket.cc11
-rw-r--r--net/base/tcp_listen_socket.h4
-rw-r--r--net/base/tcp_listen_socket_unittest.cc11
-rw-r--r--net/base/tcp_listen_socket_unittest.h11
-rw-r--r--net/base/unix_domain_socket_posix.cc186
-rw-r--r--net/base/unix_domain_socket_posix.h121
-rw-r--r--net/base/unix_domain_socket_posix_unittest.cc313
-rw-r--r--net/net.gyp5
-rw-r--r--net/tools/fetch/http_listen_socket.cc6
-rw-r--r--net/tools/fetch/http_listen_socket.h2
12 files changed, 660 insertions, 40 deletions
diff --git a/net/base/stream_listen_socket.cc b/net/base/stream_listen_socket.cc
index e5ef6c8..fb28c66 100644
--- a/net/base/stream_listen_socket.cc
+++ b/net/base/stream_listen_socket.cc
@@ -69,14 +69,14 @@ const net::BackoffEntry::Policy kSendBackoffPolicy = {
} // namespace
#if defined(OS_WIN)
-const SOCKET StreamListenSocket::kInvalidSocket = INVALID_SOCKET;
+const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET;
const int StreamListenSocket::kSocketError = SOCKET_ERROR;
#elif defined(OS_POSIX)
-const SOCKET StreamListenSocket::kInvalidSocket = -1;
+const SocketDescriptor StreamListenSocket::kInvalidSocket = -1;
const int StreamListenSocket::kSocketError = -1;
#endif
-StreamListenSocket::StreamListenSocket(SOCKET s,
+StreamListenSocket::StreamListenSocket(SocketDescriptor s,
StreamListenSocket::Delegate* del)
: socket_delegate_(del),
socket_(s),
@@ -115,8 +115,8 @@ void StreamListenSocket::Send(const string& str, bool append_linefeed) {
Send(str.data(), static_cast<int>(str.length()), append_linefeed);
}
-SOCKET StreamListenSocket::AcceptSocket() {
- SOCKET conn = HANDLE_EINTR(accept(socket_, NULL, NULL));
+SocketDescriptor StreamListenSocket::AcceptSocket() {
+ SocketDescriptor conn = HANDLE_EINTR(accept(socket_, NULL, NULL));
if (conn == kInvalidSocket)
LOG(ERROR) << "Error accepting connection.";
else
@@ -204,7 +204,7 @@ void StreamListenSocket::Close() {
socket_delegate_->DidClose(this);
}
-void StreamListenSocket::CloseSocket(SOCKET s) {
+void StreamListenSocket::CloseSocket(SocketDescriptor s) {
if (s && s != kInvalidSocket) {
UnwatchSocket();
#if defined(OS_WIN)
diff --git a/net/base/stream_listen_socket.h b/net/base/stream_listen_socket.h
index 95c3260..9e88864 100644
--- a/net/base/stream_listen_socket.h
+++ b/net/base/stream_listen_socket.h
@@ -41,7 +41,9 @@
#include "net/base/stream_listen_socket.h"
#if defined(OS_POSIX)
-typedef int SOCKET;
+typedef int SocketDescriptor;
+#else
+typedef SOCKET SocketDescriptor;
#endif
namespace net {
@@ -78,6 +80,9 @@ class NET_EXPORT StreamListenSocket
void Send(const char* bytes, int len, bool append_linefeed = false);
void Send(const std::string& str, bool append_linefeed = false);
+ static const SocketDescriptor kInvalidSocket;
+ static const int kSocketError;
+
protected:
enum WaitState {
NOT_WAITING = 0,
@@ -85,19 +90,16 @@ class NET_EXPORT StreamListenSocket
WAITING_READ = 2
};
- static const SOCKET kInvalidSocket;
- static const int kSocketError;
-
- StreamListenSocket(SOCKET s, Delegate* del);
+ StreamListenSocket(SocketDescriptor s, Delegate* del);
virtual ~StreamListenSocket();
- SOCKET AcceptSocket();
+ SocketDescriptor AcceptSocket();
virtual void Accept() = 0;
void Listen();
void Read();
void Close();
- void CloseSocket(SOCKET s);
+ void CloseSocket(SocketDescriptor s);
// Pass any value in case of Windows, because in Windows
// we are not using state.
@@ -133,7 +135,7 @@ class NET_EXPORT StreamListenSocket
void PauseReads();
void ResumeReads();
- const SOCKET socket_;
+ const SocketDescriptor socket_;
bool reads_paused_;
bool has_pending_reads_;
diff --git a/net/base/tcp_listen_socket.cc b/net/base/tcp_listen_socket.cc
index 1583bfb..513297c 100644
--- a/net/base/tcp_listen_socket.cc
+++ b/net/base/tcp_listen_socket.cc
@@ -31,7 +31,7 @@ namespace net {
// static
scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen(
const string& ip, int port, StreamListenSocket::Delegate* del) {
- SOCKET s = CreateAndBind(ip, port);
+ SocketDescriptor s = CreateAndBind(ip, port);
if (s == kInvalidSocket)
return NULL;
scoped_refptr<TCPListenSocket> sock(new TCPListenSocket(s, del));
@@ -39,18 +39,19 @@ scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen(
return sock;
}
-TCPListenSocket::TCPListenSocket(SOCKET s, StreamListenSocket::Delegate* del)
+TCPListenSocket::TCPListenSocket(SocketDescriptor s,
+ StreamListenSocket::Delegate* del)
: StreamListenSocket(s, del) {
}
TCPListenSocket::~TCPListenSocket() {}
-SOCKET TCPListenSocket::CreateAndBind(const string& ip, int port) {
+SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) {
#if defined(OS_WIN)
EnsureWinsockInit();
#endif
- SOCKET s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ SocketDescriptor s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (s != kInvalidSocket) {
#if defined(OS_POSIX)
// Allow rapid reuse.
@@ -76,7 +77,7 @@ SOCKET TCPListenSocket::CreateAndBind(const string& ip, int port) {
}
void TCPListenSocket::Accept() {
- SOCKET conn = AcceptSocket();
+ SocketDescriptor conn = AcceptSocket();
if (conn == kInvalidSocket)
return;
scoped_refptr<TCPListenSocket> sock(
diff --git a/net/base/tcp_listen_socket.h b/net/base/tcp_listen_socket.h
index 6691b5f..7398565 100644
--- a/net/base/tcp_listen_socket.h
+++ b/net/base/tcp_listen_socket.h
@@ -26,10 +26,10 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket {
protected:
friend class scoped_refptr<TCPListenSocket>;
- TCPListenSocket(SOCKET s, StreamListenSocket::Delegate* del);
+ TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del);
virtual ~TCPListenSocket();
- static SOCKET CreateAndBind(const std::string& ip, int port);
+ static SocketDescriptor CreateAndBind(const std::string& ip, int port);
// Implements StreamListenSocket::Accept.
virtual void Accept() OVERRIDE;
diff --git a/net/base/tcp_listen_socket_unittest.cc b/net/base/tcp_listen_socket_unittest.cc
index 61b259c..677a804 100644
--- a/net/base/tcp_listen_socket_unittest.cc
+++ b/net/base/tcp_listen_socket_unittest.cc
@@ -48,7 +48,7 @@ void TCPListenSocketTester::SetUp() {
// verify the connect/accept and setup test_socket_
test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
- ASSERT_NE(INVALID_SOCKET, test_socket_);
+ ASSERT_NE(StreamListenSocket::kInvalidSocket, test_socket_);
struct sockaddr_in client;
client.sin_family = AF_INET;
client.sin_addr.s_addr = inet_addr(kLoopback);
@@ -56,7 +56,7 @@ void TCPListenSocketTester::SetUp() {
int ret = HANDLE_EINTR(
connect(test_socket_, reinterpret_cast<sockaddr*>(&client),
sizeof(client)));
- ASSERT_NE(ret, SOCKET_ERROR);
+ ASSERT_NE(ret, StreamListenSocket::kSocketError);
NextAction();
ASSERT_EQ(ACTION_ACCEPT, last_action_.type());
@@ -100,7 +100,7 @@ int TCPListenSocketTester::ClearTestSocket() {
int len_ret = 0;
do {
int len = HANDLE_EINTR(recv(test_socket_, buf, kReadBufSize, 0));
- if (len == SOCKET_ERROR || len == 0) {
+ if (len == StreamListenSocket::kSocketError || len == 0) {
break;
} else {
len_ret += len;
@@ -210,10 +210,11 @@ void TCPListenSocketTester::TestServerSendMultiple() {
}
}
-bool TCPListenSocketTester::Send(SOCKET sock, const std::string& str) {
+bool TCPListenSocketTester::Send(SocketDescriptor sock,
+ const std::string& str) {
int len = static_cast<int>(str.length());
int send_len = HANDLE_EINTR(send(sock, str.data(), len, 0));
- if (send_len == SOCKET_ERROR) {
+ if (send_len == StreamListenSocket::kSocketError) {
LOG(ERROR) << "send failed: " << errno;
return false;
} else if (send_len != len) {
diff --git a/net/base/tcp_listen_socket_unittest.h b/net/base/tcp_listen_socket_unittest.h
index 71870f7..e1069b6 100644
--- a/net/base/tcp_listen_socket_unittest.h
+++ b/net/base/tcp_listen_socket_unittest.h
@@ -29,13 +29,6 @@
#include "net/base/winsock_init.h"
#include "testing/gtest/include/gtest/gtest.h"
-#if defined(OS_POSIX)
-// Used same name as in Windows to avoid #ifdef where referenced
-#define SOCKET int
-const int INVALID_SOCKET = -1;
-const int SOCKET_ERROR = -1;
-#endif
-
namespace net {
enum ActionType {
@@ -95,7 +88,7 @@ class TCPListenSocketTester :
// verify multiple sends and reads from server to client.
void TestServerSendMultiple();
- virtual bool Send(SOCKET sock, const std::string& str);
+ virtual bool Send(SocketDescriptor sock, const std::string& str);
// StreamListenSocket::Delegate:
virtual void DidAccept(StreamListenSocket* server,
@@ -110,7 +103,7 @@ class TCPListenSocketTester :
StreamListenSocket* connection_;
TCPListenSocketTestAction last_action_;
- SOCKET test_socket_;
+ SocketDescriptor test_socket_;
static const int kTestPort;
base::Lock lock_; // protects |queue_| and wraps |cv_|
diff --git a/net/base/unix_domain_socket_posix.cc b/net/base/unix_domain_socket_posix.cc
new file mode 100644
index 0000000..d852ed8
--- /dev/null
+++ b/net/base/unix_domain_socket_posix.cc
@@ -0,0 +1,186 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/base/unix_domain_socket_posix.h"
+
+#include <cstring>
+#include <string>
+
+#include <errno.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/eintr_wrapper.h"
+#include "base/threading/platform_thread.h"
+#include "build/build_config.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+
+namespace net {
+
+namespace {
+
+bool NoAuthenticationCallback(uid_t, gid_t) {
+ return true;
+}
+
+bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) {
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ struct ucred user_cred;
+ socklen_t len = sizeof(user_cred);
+ if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1)
+ return false;
+ *user_id = user_cred.uid;
+ *group_id = user_cred.gid;
+#else
+ if (getpeereid(socket, user_id, group_id) == -1)
+ return false;
+#endif
+ return true;
+}
+
+} // namespace
+
+// static
+UnixDomainSocket::AuthCallback NoAuthentication() {
+ return base::Bind(NoAuthenticationCallback);
+}
+
+// static
+UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback,
+ bool use_abstract_namespace) {
+ SocketDescriptor s = CreateAndBind(path, use_abstract_namespace);
+ if (s == kInvalidSocket)
+ return NULL;
+ UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback);
+ sock->Listen();
+ return sock;
+}
+
+// static
+scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback) {
+ return CreateAndListenInternal(path, del, auth_callback, false);
+}
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+// static
+scoped_refptr<UnixDomainSocket>
+UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback) {
+ return make_scoped_refptr(
+ CreateAndListenInternal(path, del, auth_callback, true));
+}
+#endif
+
+UnixDomainSocket::UnixDomainSocket(
+ SocketDescriptor s,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback)
+ : StreamListenSocket(s, del),
+ auth_callback_(auth_callback) {}
+
+UnixDomainSocket::~UnixDomainSocket() {}
+
+// static
+SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path,
+ bool use_abstract_namespace) {
+ sockaddr_un addr;
+ static const size_t kPathMax = sizeof(addr.sun_path);
+ if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax)
+ return kInvalidSocket;
+ const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0);
+ if (s == kInvalidSocket)
+ return kInvalidSocket;
+ memset(&addr, 0, sizeof(addr));
+ addr.sun_family = AF_UNIX;
+ socklen_t addr_len;
+ if (use_abstract_namespace) {
+ // Convert the path given into abstract socket name. It must start with
+ // the '\0' character, so we are adding it. |addr_len| must specify the
+ // length of the structure exactly, as potentially the socket name may
+ // have '\0' characters embedded (although we don't support this).
+ // Note that addr.sun_path is already zero initialized.
+ memcpy(addr.sun_path + 1, path.c_str(), path.size());
+ addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
+ } else {
+ memcpy(addr.sun_path, path.c_str(), path.size());
+ addr_len = sizeof(sockaddr_un);
+ }
+ if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) {
+ LOG(ERROR) << "Could not bind unix domain socket to " << path;
+ if (use_abstract_namespace)
+ LOG(ERROR) << " (with abstract namespace enabled)";
+ if (HANDLE_EINTR(close(s)) < 0)
+ LOG(ERROR) << "close() error";
+ return kInvalidSocket;
+ }
+ return s;
+}
+
+void UnixDomainSocket::Accept() {
+ SocketDescriptor conn = StreamListenSocket::AcceptSocket();
+ if (conn == kInvalidSocket)
+ return;
+ uid_t user_id;
+ gid_t group_id;
+ if (!GetPeerIds(conn, &user_id, &group_id) ||
+ !auth_callback_.Run(user_id, group_id)) {
+ if (HANDLE_EINTR(close(conn)) < 0)
+ LOG(ERROR) << "close() error";
+ return;
+ }
+ scoped_refptr<UnixDomainSocket> sock(
+ new UnixDomainSocket(conn, socket_delegate_, auth_callback_));
+ // It's up to the delegate to AddRef if it wants to keep it around.
+ sock->WatchSocket(WAITING_READ);
+ socket_delegate_->DidAccept(this, sock);
+}
+
+UnixDomainSocketFactory::UnixDomainSocketFactory(
+ const std::string& path,
+ const UnixDomainSocket::AuthCallback& auth_callback)
+ : path_(path),
+ auth_callback_(auth_callback) {}
+
+UnixDomainSocketFactory::~UnixDomainSocketFactory() {}
+
+scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const {
+ return UnixDomainSocket::CreateAndListen(path_, delegate, auth_callback_);
+}
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+
+UnixDomainSocketWithAbstractNamespaceFactory::
+UnixDomainSocketWithAbstractNamespaceFactory(
+ const std::string& path,
+ const UnixDomainSocket::AuthCallback& auth_callback)
+ : UnixDomainSocketFactory(path, auth_callback) {}
+
+UnixDomainSocketWithAbstractNamespaceFactory::
+~UnixDomainSocketWithAbstractNamespaceFactory() {}
+
+scoped_refptr<StreamListenSocket>
+UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const {
+ return UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ path_, delegate, auth_callback_);
+}
+
+#endif
+
+} // namespace net
diff --git a/net/base/unix_domain_socket_posix.h b/net/base/unix_domain_socket_posix.h
new file mode 100644
index 0000000..7730900
--- /dev/null
+++ b/net/base/unix_domain_socket_posix.h
@@ -0,0 +1,121 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_BASE_UNIX_DOMAIN_SOCKET_POSIX_H_
+#define NET_BASE_UNIX_DOMAIN_SOCKET_POSIX_H_
+#pragma once
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/callback_forward.h"
+#include "base/compiler_specific.h"
+#include "base/memory/ref_counted.h"
+#include "build/build_config.h"
+#include "net/base/net_export.h"
+#include "net/base/stream_listen_socket.h"
+
+#if defined(OS_ANDROID) || defined(OS_LINUX)
+// Feature only supported on Linux currently. This lets the Unix Domain Socket
+// not be backed by the file system.
+#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
+#endif
+
+namespace net {
+
+// Unix Domain Socket Implementation. Supports abstract namespaces on Linux.
+class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
+ public:
+ // Callback that returns whether the already connected client, identified by
+ // its process |user_id| and |group_id|, is allowed to keep the connection
+ // open. Note that the socket is closed immediately in case the callback
+ // returns false.
+ typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback;
+
+ // Returns an authentication callback that always grants access for
+ // convenience in case you don't want to use authentication.
+ static AuthCallback NoAuthentication();
+
+ // Note that the returned UnixDomainSocket instance does not take ownership of
+ // |del|.
+ static scoped_refptr<UnixDomainSocket> CreateAndListen(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+ // Same as above except that the created socket uses the abstract namespace
+ // which is a Linux-only feature.
+ static scoped_refptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+#endif
+
+ private:
+ UnixDomainSocket(SocketDescriptor s,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+ virtual ~UnixDomainSocket();
+
+ static UnixDomainSocket* CreateAndListenInternal(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback,
+ bool use_abstract_namespace);
+
+ static SocketDescriptor CreateAndBind(const std::string& path,
+ bool use_abstract_namespace);
+
+ // StreamListenSocket:
+ virtual void Accept() OVERRIDE;
+
+ AuthCallback auth_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainSocket);
+};
+
+// Factory that can be used to instantiate UnixDomainSocket.
+class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory {
+ public:
+ // Note that this class does not take ownership of the provided delegate.
+ UnixDomainSocketFactory(const std::string& path,
+ const UnixDomainSocket::AuthCallback& auth_callback);
+ virtual ~UnixDomainSocketFactory();
+
+ // StreamListenSocketFactory:
+ virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const OVERRIDE;
+
+ protected:
+ const std::string path_;
+ const UnixDomainSocket::AuthCallback auth_callback_;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketFactory);
+};
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+// Use this factory to instantiate UnixDomainSocket using the abstract
+// namespace feature (only supported on Linux).
+class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory
+ : public UnixDomainSocketFactory {
+ public:
+ UnixDomainSocketWithAbstractNamespaceFactory(
+ const std::string& path,
+ const UnixDomainSocket::AuthCallback& auth_callback);
+ virtual ~UnixDomainSocketWithAbstractNamespaceFactory();
+
+ // UnixDomainSocketFactory:
+ virtual scoped_refptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const OVERRIDE;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketWithAbstractNamespaceFactory);
+};
+#endif
+
+} // namespace net
+
+#endif // NET_BASE_UNIX_DOMAIN_SOCKET_POSIX_H_
diff --git a/net/base/unix_domain_socket_posix_unittest.cc b/net/base/unix_domain_socket_posix_unittest.cc
new file mode 100644
index 0000000..e72ac59
--- /dev/null
+++ b/net/base/unix_domain_socket_posix_unittest.cc
@@ -0,0 +1,313 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <queue>
+#include <string>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/eintr_wrapper.h"
+#include "base/file_path.h"
+#include "base/file_util.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop.h"
+#include "base/synchronization/condition_variable.h"
+#include "base/synchronization/lock.h"
+#include "base/threading/platform_thread.h"
+#include "base/threading/thread.h"
+#include "net/base/unix_domain_socket_posix.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+using std::queue;
+using std::string;
+
+namespace net {
+namespace {
+
+const char kSocketFilename[] = "unix_domain_socket_for_testing";
+const char kInvalidSocketPath[] = "/invalid/path";
+const char kMsg[] = "hello";
+
+enum EventType {
+ EVENT_ACCEPT,
+ EVENT_AUTH_DENIED,
+ EVENT_AUTH_GRANTED,
+ EVENT_CLOSE,
+ EVENT_LISTEN,
+ EVENT_READ,
+};
+
+string MakeSocketPath() {
+ FilePath temp_dir;
+ file_util::GetTempDir(&temp_dir);
+ return temp_dir.Append(kSocketFilename).value();
+}
+
+class EventManager : public base::RefCounted<EventManager> {
+ public:
+ EventManager() : condition_(&mutex_) {}
+
+ bool HasPendingEvent() {
+ base::AutoLock lock(mutex_);
+ return !events_.empty();
+ }
+
+ void Notify(EventType event) {
+ base::AutoLock lock(mutex_);
+ events_.push(event);
+ condition_.Broadcast();
+ }
+
+ EventType WaitForEvent() {
+ base::AutoLock lock(mutex_);
+ while (events_.empty())
+ condition_.Wait();
+ EventType event = events_.front();
+ events_.pop();
+ return event;
+ }
+
+ private:
+ friend class base::RefCounted<EventManager>;
+ virtual ~EventManager() {}
+
+ queue<EventType> events_;
+ base::Lock mutex_;
+ base::ConditionVariable condition_;
+};
+
+class TestListenSocketDelegate : public StreamListenSocket::Delegate {
+ public:
+ explicit TestListenSocketDelegate(
+ const scoped_refptr<EventManager>& event_manager)
+ : event_manager_(event_manager) {}
+
+ virtual void DidAccept(StreamListenSocket* server,
+ StreamListenSocket* connection) OVERRIDE {
+ LOG(ERROR) << __PRETTY_FUNCTION__;
+ connection_ = connection;
+ Notify(EVENT_ACCEPT);
+ }
+
+ virtual void DidRead(StreamListenSocket* connection,
+ const char* data,
+ int len) OVERRIDE {
+ {
+ base::AutoLock lock(mutex_);
+ DCHECK(len);
+ data_.assign(data, len - 1);
+ }
+ Notify(EVENT_READ);
+ }
+
+ virtual void DidClose(StreamListenSocket* sock) OVERRIDE {
+ Notify(EVENT_CLOSE);
+ }
+
+ void OnListenCompleted() {
+ Notify(EVENT_LISTEN);
+ }
+
+ string ReceivedData() {
+ base::AutoLock lock(mutex_);
+ return data_;
+ }
+
+ private:
+ void Notify(EventType event) {
+ event_manager_->Notify(event);
+ }
+
+ const scoped_refptr<EventManager> event_manager_;
+ scoped_refptr<StreamListenSocket> connection_;
+ base::Lock mutex_;
+ string data_;
+};
+
+bool UserCanConnectCallback(
+ bool allow_user, const scoped_refptr<EventManager>& event_manager,
+ uid_t, gid_t) {
+ event_manager->Notify(
+ allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED);
+ return allow_user;
+}
+
+class UnixDomainSocketTestHelper : public testing::Test {
+ public:
+ void CreateAndListen() {
+ socket_ = UnixDomainSocket::CreateAndListen(
+ file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
+ socket_delegate_->OnListenCompleted();
+ }
+
+ protected:
+ UnixDomainSocketTestHelper(const string& path, bool allow_user)
+ : file_path_(path),
+ allow_user_(allow_user) {}
+
+ virtual void SetUp() OVERRIDE {
+ event_manager_ = new EventManager();
+ socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
+ DeleteSocketFile();
+ }
+
+ virtual void TearDown() OVERRIDE {
+ DeleteSocketFile();
+ socket_ = NULL;
+ socket_delegate_.reset();
+ event_manager_ = NULL;
+ }
+
+ UnixDomainSocket::AuthCallback MakeAuthCallback() {
+ return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
+ }
+
+ void DeleteSocketFile() {
+ ASSERT_FALSE(file_path_.empty());
+ file_util::Delete(file_path_, false /* not recursive */);
+ }
+
+ SocketDescriptor CreateClientSocket() {
+ const SocketDescriptor sock = socket(PF_UNIX, SOCK_STREAM, 0);
+ if (sock < 0) {
+ LOG(ERROR) << "socket() error";
+ return StreamListenSocket::kInvalidSocket;
+ }
+ sockaddr_un addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.sun_family = AF_UNIX;
+ socklen_t addr_len;
+ strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
+ addr_len = sizeof(sockaddr_un);
+ if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
+ LOG(ERROR) << "connect() error";
+ return StreamListenSocket::kInvalidSocket;
+ }
+ return sock;
+ }
+
+ scoped_ptr<base::Thread> CreateAndRunServerThread() {
+ base::Thread::Options options;
+ options.message_loop_type = MessageLoop::TYPE_IO;
+ scoped_ptr<base::Thread> thread(new base::Thread("socketio_test"));
+ thread->StartWithOptions(options);
+ thread->message_loop()->PostTask(
+ FROM_HERE,
+ base::Bind(&UnixDomainSocketTestHelper::CreateAndListen,
+ base::Unretained(this)));
+ return thread.Pass();
+ }
+
+ const FilePath file_path_;
+ const bool allow_user_;
+ scoped_refptr<EventManager> event_manager_;
+ scoped_ptr<TestListenSocketDelegate> socket_delegate_;
+ scoped_refptr<UnixDomainSocket> socket_;
+};
+
+class UnixDomainSocketTest : public UnixDomainSocketTestHelper {
+ protected:
+ UnixDomainSocketTest()
+ : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
+};
+
+class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper {
+ protected:
+ UnixDomainSocketTestWithInvalidPath()
+ : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {}
+};
+
+class UnixDomainSocketTestWithForbiddenUser
+ : public UnixDomainSocketTestHelper {
+ protected:
+ UnixDomainSocketTestWithForbiddenUser()
+ : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
+};
+
+TEST_F(UnixDomainSocketTest, CreateAndListen) {
+ CreateAndListen();
+ EXPECT_FALSE(socket_.get() == NULL);
+}
+
+TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
+ CreateAndListen();
+ EXPECT_TRUE(socket_.get() == NULL);
+}
+
+#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
+// Test with an invalid path to make sure that the socket is not backed by a
+// file.
+TEST_F(UnixDomainSocketTestWithInvalidPath,
+ CreateAndListenWithAbstractNamespace) {
+ socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
+ EXPECT_FALSE(socket_.get() == NULL);
+}
+#endif
+
+TEST_F(UnixDomainSocketTest, TestWithClient) {
+ const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
+ EventType event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_LISTEN, event);
+
+ // Create the client socket.
+ const SocketDescriptor sock = CreateClientSocket();
+ ASSERT_NE(StreamListenSocket::kInvalidSocket, sock);
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_AUTH_GRANTED, event);
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_ACCEPT, event);
+
+ // Send a message from the client to the server.
+ ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
+ ASSERT_NE(-1, ret);
+ ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret));
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_READ, event);
+ ASSERT_EQ(kMsg, socket_delegate_->ReceivedData());
+
+ // Close the client socket.
+ ret = HANDLE_EINTR(close(sock));
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_CLOSE, event);
+}
+
+TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
+ const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
+ EventType event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_LISTEN, event);
+ const SocketDescriptor sock = CreateClientSocket();
+ ASSERT_NE(StreamListenSocket::kInvalidSocket, sock);
+
+ event = event_manager_->WaitForEvent();
+ ASSERT_EQ(EVENT_AUTH_DENIED, event);
+
+ // Wait until the file descriptor is closed by the server.
+ struct pollfd poll_fd;
+ poll_fd.fd = sock;
+ poll_fd.events = POLLIN;
+ poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */);
+
+ // Send() must fail.
+ ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
+ ASSERT_EQ(-1, ret);
+ ASSERT_EQ(EPIPE, errno);
+ ASSERT_FALSE(event_manager_->HasPendingEvent());
+}
+
+} // namespace
+} // namespace net
diff --git a/net/net.gyp b/net/net.gyp
index d1a3b46..a2d1d55 100644
--- a/net/net.gyp
+++ b/net/net.gyp
@@ -261,6 +261,8 @@
'base/transport_security_state.cc',
'base/transport_security_state.h',
'base/transport_security_state_static.h',
+ 'base/unix_domain_socket_posix.cc',
+ 'base/unix_domain_socket_posix.h',
'base/upload_data.cc',
'base/upload_data.h',
'base/upload_data_stream.cc',
@@ -1091,8 +1093,9 @@
'base/test_certificate_data.h',
'base/test_completion_callback_unittest.cc',
'base/transport_security_state_unittest.cc',
- 'base/upload_data_unittest.cc',
+ 'base/unix_domain_socket_posix_unittest.cc',
'base/upload_data_stream_unittest.cc',
+ 'base/upload_data_unittest.cc',
'base/x509_certificate_unittest.cc',
'base/x509_cert_types_unittest.cc',
'base/x509_util_nss_unittest.cc',
diff --git a/net/tools/fetch/http_listen_socket.cc b/net/tools/fetch/http_listen_socket.cc
index dbb34e6..ddad463 100644
--- a/net/tools/fetch/http_listen_socket.cc
+++ b/net/tools/fetch/http_listen_socket.cc
@@ -13,7 +13,7 @@
#include "net/tools/fetch/http_server_request_info.h"
#include "net/tools/fetch/http_server_response_info.h"
-HttpListenSocket::HttpListenSocket(SOCKET s,
+HttpListenSocket::HttpListenSocket(SocketDescriptor s,
HttpListenSocket::Delegate* delegate)
: ALLOW_THIS_IN_INITIALIZER_LIST(net::TCPListenSocket(s, this)),
delegate_(delegate) {
@@ -23,7 +23,7 @@ HttpListenSocket::~HttpListenSocket() {
}
void HttpListenSocket::Accept() {
- SOCKET conn = net::TCPListenSocket::AcceptSocket();
+ SocketDescriptor conn = net::TCPListenSocket::AcceptSocket();
DCHECK_NE(conn, net::TCPListenSocket::kInvalidSocket);
if (conn == net::TCPListenSocket::kInvalidSocket) {
// TODO
@@ -40,7 +40,7 @@ scoped_refptr<HttpListenSocket> HttpListenSocket::CreateAndListen(
const std::string& ip,
int port,
HttpListenSocket::Delegate* delegate) {
- SOCKET s = net::TCPListenSocket::CreateAndBind(ip, port);
+ SocketDescriptor s = net::TCPListenSocket::CreateAndBind(ip, port);
if (s == net::TCPListenSocket::kInvalidSocket) {
// TODO (ibrar): error handling.
} else {
diff --git a/net/tools/fetch/http_listen_socket.h b/net/tools/fetch/http_listen_socket.h
index 4c3b27b..5eaf7a9 100644
--- a/net/tools/fetch/http_listen_socket.h
+++ b/net/tools/fetch/http_listen_socket.h
@@ -50,7 +50,7 @@ class HttpListenSocket : public net::TCPListenSocket,
static const int kReadBufSize = 16 * 1024;
// Must run in the IO thread.
- HttpListenSocket(SOCKET s, HttpListenSocket::Delegate* del);
+ HttpListenSocket(SocketDescriptor s, HttpListenSocket::Delegate* del);
virtual ~HttpListenSocket();
// Expects the raw data to be stored in recv_data_. If parsing is successful,