diff options
author | pliard@chromium.org <pliard@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-06-01 09:41:00 +0000 |
---|---|---|
committer | pliard@chromium.org <pliard@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2012-06-01 09:41:00 +0000 |
commit | a7885b94f220458ead7abace4d4fe004faca3989 (patch) | |
tree | edb38efe969d5cb5c947980efb6d9430acf47d9b | |
parent | 84d88d337868e1a9dbee08b1c5acc81447d47e36 (diff) | |
download | chromium_src-a7885b94f220458ead7abace4d4fe004faca3989.zip chromium_src-a7885b94f220458ead7abace4d4fe004faca3989.tar.gz chromium_src-a7885b94f220458ead7abace4d4fe004faca3989.tar.bz2 |
Add net/base/unix_domain_socket_posix*.
This is part of Chrome for Android upstreaming.
It will be used by DevTools on Android.
Review URL: https://chromiumcodereview.appspot.com/10391053
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@140010 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r-- | net/base/stream_listen_socket.cc | 12 | ||||
-rw-r--r-- | net/base/stream_listen_socket.h | 18 | ||||
-rw-r--r-- | net/base/tcp_listen_socket.cc | 11 | ||||
-rw-r--r-- | net/base/tcp_listen_socket.h | 4 | ||||
-rw-r--r-- | net/base/tcp_listen_socket_unittest.cc | 11 | ||||
-rw-r--r-- | net/base/tcp_listen_socket_unittest.h | 11 | ||||
-rw-r--r-- | net/base/unix_domain_socket_posix.cc | 186 | ||||
-rw-r--r-- | net/base/unix_domain_socket_posix.h | 121 | ||||
-rw-r--r-- | net/base/unix_domain_socket_posix_unittest.cc | 313 | ||||
-rw-r--r-- | net/net.gyp | 5 | ||||
-rw-r--r-- | net/tools/fetch/http_listen_socket.cc | 6 | ||||
-rw-r--r-- | net/tools/fetch/http_listen_socket.h | 2 |
12 files changed, 660 insertions, 40 deletions
diff --git a/net/base/stream_listen_socket.cc b/net/base/stream_listen_socket.cc index e5ef6c8..fb28c66 100644 --- a/net/base/stream_listen_socket.cc +++ b/net/base/stream_listen_socket.cc @@ -69,14 +69,14 @@ const net::BackoffEntry::Policy kSendBackoffPolicy = { } // namespace #if defined(OS_WIN) -const SOCKET StreamListenSocket::kInvalidSocket = INVALID_SOCKET; +const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET; const int StreamListenSocket::kSocketError = SOCKET_ERROR; #elif defined(OS_POSIX) -const SOCKET StreamListenSocket::kInvalidSocket = -1; +const SocketDescriptor StreamListenSocket::kInvalidSocket = -1; const int StreamListenSocket::kSocketError = -1; #endif -StreamListenSocket::StreamListenSocket(SOCKET s, +StreamListenSocket::StreamListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del) : socket_delegate_(del), socket_(s), @@ -115,8 +115,8 @@ void StreamListenSocket::Send(const string& str, bool append_linefeed) { Send(str.data(), static_cast<int>(str.length()), append_linefeed); } -SOCKET StreamListenSocket::AcceptSocket() { - SOCKET conn = HANDLE_EINTR(accept(socket_, NULL, NULL)); +SocketDescriptor StreamListenSocket::AcceptSocket() { + SocketDescriptor conn = HANDLE_EINTR(accept(socket_, NULL, NULL)); if (conn == kInvalidSocket) LOG(ERROR) << "Error accepting connection."; else @@ -204,7 +204,7 @@ void StreamListenSocket::Close() { socket_delegate_->DidClose(this); } -void StreamListenSocket::CloseSocket(SOCKET s) { +void StreamListenSocket::CloseSocket(SocketDescriptor s) { if (s && s != kInvalidSocket) { UnwatchSocket(); #if defined(OS_WIN) diff --git a/net/base/stream_listen_socket.h b/net/base/stream_listen_socket.h index 95c3260..9e88864 100644 --- a/net/base/stream_listen_socket.h +++ b/net/base/stream_listen_socket.h @@ -41,7 +41,9 @@ #include "net/base/stream_listen_socket.h" #if defined(OS_POSIX) -typedef int SOCKET; +typedef int SocketDescriptor; +#else +typedef SOCKET SocketDescriptor; #endif namespace net { @@ -78,6 +80,9 @@ class NET_EXPORT StreamListenSocket void Send(const char* bytes, int len, bool append_linefeed = false); void Send(const std::string& str, bool append_linefeed = false); + static const SocketDescriptor kInvalidSocket; + static const int kSocketError; + protected: enum WaitState { NOT_WAITING = 0, @@ -85,19 +90,16 @@ class NET_EXPORT StreamListenSocket WAITING_READ = 2 }; - static const SOCKET kInvalidSocket; - static const int kSocketError; - - StreamListenSocket(SOCKET s, Delegate* del); + StreamListenSocket(SocketDescriptor s, Delegate* del); virtual ~StreamListenSocket(); - SOCKET AcceptSocket(); + SocketDescriptor AcceptSocket(); virtual void Accept() = 0; void Listen(); void Read(); void Close(); - void CloseSocket(SOCKET s); + void CloseSocket(SocketDescriptor s); // Pass any value in case of Windows, because in Windows // we are not using state. @@ -133,7 +135,7 @@ class NET_EXPORT StreamListenSocket void PauseReads(); void ResumeReads(); - const SOCKET socket_; + const SocketDescriptor socket_; bool reads_paused_; bool has_pending_reads_; diff --git a/net/base/tcp_listen_socket.cc b/net/base/tcp_listen_socket.cc index 1583bfb..513297c 100644 --- a/net/base/tcp_listen_socket.cc +++ b/net/base/tcp_listen_socket.cc @@ -31,7 +31,7 @@ namespace net { // static scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen( const string& ip, int port, StreamListenSocket::Delegate* del) { - SOCKET s = CreateAndBind(ip, port); + SocketDescriptor s = CreateAndBind(ip, port); if (s == kInvalidSocket) return NULL; scoped_refptr<TCPListenSocket> sock(new TCPListenSocket(s, del)); @@ -39,18 +39,19 @@ scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen( return sock; } -TCPListenSocket::TCPListenSocket(SOCKET s, StreamListenSocket::Delegate* del) +TCPListenSocket::TCPListenSocket(SocketDescriptor s, + StreamListenSocket::Delegate* del) : StreamListenSocket(s, del) { } TCPListenSocket::~TCPListenSocket() {} -SOCKET TCPListenSocket::CreateAndBind(const string& ip, int port) { +SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) { #if defined(OS_WIN) EnsureWinsockInit(); #endif - SOCKET s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + SocketDescriptor s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (s != kInvalidSocket) { #if defined(OS_POSIX) // Allow rapid reuse. @@ -76,7 +77,7 @@ SOCKET TCPListenSocket::CreateAndBind(const string& ip, int port) { } void TCPListenSocket::Accept() { - SOCKET conn = AcceptSocket(); + SocketDescriptor conn = AcceptSocket(); if (conn == kInvalidSocket) return; scoped_refptr<TCPListenSocket> sock( diff --git a/net/base/tcp_listen_socket.h b/net/base/tcp_listen_socket.h index 6691b5f..7398565 100644 --- a/net/base/tcp_listen_socket.h +++ b/net/base/tcp_listen_socket.h @@ -26,10 +26,10 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket { protected: friend class scoped_refptr<TCPListenSocket>; - TCPListenSocket(SOCKET s, StreamListenSocket::Delegate* del); + TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del); virtual ~TCPListenSocket(); - static SOCKET CreateAndBind(const std::string& ip, int port); + static SocketDescriptor CreateAndBind(const std::string& ip, int port); // Implements StreamListenSocket::Accept. virtual void Accept() OVERRIDE; diff --git a/net/base/tcp_listen_socket_unittest.cc b/net/base/tcp_listen_socket_unittest.cc index 61b259c..677a804 100644 --- a/net/base/tcp_listen_socket_unittest.cc +++ b/net/base/tcp_listen_socket_unittest.cc @@ -48,7 +48,7 @@ void TCPListenSocketTester::SetUp() { // verify the connect/accept and setup test_socket_ test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - ASSERT_NE(INVALID_SOCKET, test_socket_); + ASSERT_NE(StreamListenSocket::kInvalidSocket, test_socket_); struct sockaddr_in client; client.sin_family = AF_INET; client.sin_addr.s_addr = inet_addr(kLoopback); @@ -56,7 +56,7 @@ void TCPListenSocketTester::SetUp() { int ret = HANDLE_EINTR( connect(test_socket_, reinterpret_cast<sockaddr*>(&client), sizeof(client))); - ASSERT_NE(ret, SOCKET_ERROR); + ASSERT_NE(ret, StreamListenSocket::kSocketError); NextAction(); ASSERT_EQ(ACTION_ACCEPT, last_action_.type()); @@ -100,7 +100,7 @@ int TCPListenSocketTester::ClearTestSocket() { int len_ret = 0; do { int len = HANDLE_EINTR(recv(test_socket_, buf, kReadBufSize, 0)); - if (len == SOCKET_ERROR || len == 0) { + if (len == StreamListenSocket::kSocketError || len == 0) { break; } else { len_ret += len; @@ -210,10 +210,11 @@ void TCPListenSocketTester::TestServerSendMultiple() { } } -bool TCPListenSocketTester::Send(SOCKET sock, const std::string& str) { +bool TCPListenSocketTester::Send(SocketDescriptor sock, + const std::string& str) { int len = static_cast<int>(str.length()); int send_len = HANDLE_EINTR(send(sock, str.data(), len, 0)); - if (send_len == SOCKET_ERROR) { + if (send_len == StreamListenSocket::kSocketError) { LOG(ERROR) << "send failed: " << errno; return false; } else if (send_len != len) { diff --git a/net/base/tcp_listen_socket_unittest.h b/net/base/tcp_listen_socket_unittest.h index 71870f7..e1069b6 100644 --- a/net/base/tcp_listen_socket_unittest.h +++ b/net/base/tcp_listen_socket_unittest.h @@ -29,13 +29,6 @@ #include "net/base/winsock_init.h" #include "testing/gtest/include/gtest/gtest.h" -#if defined(OS_POSIX) -// Used same name as in Windows to avoid #ifdef where referenced -#define SOCKET int -const int INVALID_SOCKET = -1; -const int SOCKET_ERROR = -1; -#endif - namespace net { enum ActionType { @@ -95,7 +88,7 @@ class TCPListenSocketTester : // verify multiple sends and reads from server to client. void TestServerSendMultiple(); - virtual bool Send(SOCKET sock, const std::string& str); + virtual bool Send(SocketDescriptor sock, const std::string& str); // StreamListenSocket::Delegate: virtual void DidAccept(StreamListenSocket* server, @@ -110,7 +103,7 @@ class TCPListenSocketTester : StreamListenSocket* connection_; TCPListenSocketTestAction last_action_; - SOCKET test_socket_; + SocketDescriptor test_socket_; static const int kTestPort; base::Lock lock_; // protects |queue_| and wraps |cv_| diff --git a/net/base/unix_domain_socket_posix.cc b/net/base/unix_domain_socket_posix.cc new file mode 100644 index 0000000..d852ed8 --- /dev/null +++ b/net/base/unix_domain_socket_posix.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/base/unix_domain_socket_posix.h" + +#include <cstring> +#include <string> + +#include <errno.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/eintr_wrapper.h" +#include "base/threading/platform_thread.h" +#include "build/build_config.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" + +namespace net { + +namespace { + +bool NoAuthenticationCallback(uid_t, gid_t) { + return true; +} + +bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) { +#if defined(OS_LINUX) || defined(OS_ANDROID) + struct ucred user_cred; + socklen_t len = sizeof(user_cred); + if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1) + return false; + *user_id = user_cred.uid; + *group_id = user_cred.gid; +#else + if (getpeereid(socket, user_id, group_id) == -1) + return false; +#endif + return true; +} + +} // namespace + +// static +UnixDomainSocket::AuthCallback NoAuthentication() { + return base::Bind(NoAuthenticationCallback); +} + +// static +UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback, + bool use_abstract_namespace) { + SocketDescriptor s = CreateAndBind(path, use_abstract_namespace); + if (s == kInvalidSocket) + return NULL; + UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback); + sock->Listen(); + return sock; +} + +// static +scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) { + return CreateAndListenInternal(path, del, auth_callback, false); +} + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) +// static +scoped_refptr<UnixDomainSocket> +UnixDomainSocket::CreateAndListenWithAbstractNamespace( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) { + return make_scoped_refptr( + CreateAndListenInternal(path, del, auth_callback, true)); +} +#endif + +UnixDomainSocket::UnixDomainSocket( + SocketDescriptor s, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) + : StreamListenSocket(s, del), + auth_callback_(auth_callback) {} + +UnixDomainSocket::~UnixDomainSocket() {} + +// static +SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path, + bool use_abstract_namespace) { + sockaddr_un addr; + static const size_t kPathMax = sizeof(addr.sun_path); + if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax) + return kInvalidSocket; + const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0); + if (s == kInvalidSocket) + return kInvalidSocket; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + socklen_t addr_len; + if (use_abstract_namespace) { + // Convert the path given into abstract socket name. It must start with + // the '\0' character, so we are adding it. |addr_len| must specify the + // length of the structure exactly, as potentially the socket name may + // have '\0' characters embedded (although we don't support this). + // Note that addr.sun_path is already zero initialized. + memcpy(addr.sun_path + 1, path.c_str(), path.size()); + addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; + } else { + memcpy(addr.sun_path, path.c_str(), path.size()); + addr_len = sizeof(sockaddr_un); + } + if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) { + LOG(ERROR) << "Could not bind unix domain socket to " << path; + if (use_abstract_namespace) + LOG(ERROR) << " (with abstract namespace enabled)"; + if (HANDLE_EINTR(close(s)) < 0) + LOG(ERROR) << "close() error"; + return kInvalidSocket; + } + return s; +} + +void UnixDomainSocket::Accept() { + SocketDescriptor conn = StreamListenSocket::AcceptSocket(); + if (conn == kInvalidSocket) + return; + uid_t user_id; + gid_t group_id; + if (!GetPeerIds(conn, &user_id, &group_id) || + !auth_callback_.Run(user_id, group_id)) { + if (HANDLE_EINTR(close(conn)) < 0) + LOG(ERROR) << "close() error"; + return; + } + scoped_refptr<UnixDomainSocket> sock( + new UnixDomainSocket(conn, socket_delegate_, auth_callback_)); + // It's up to the delegate to AddRef if it wants to keep it around. + sock->WatchSocket(WAITING_READ); + socket_delegate_->DidAccept(this, sock); +} + +UnixDomainSocketFactory::UnixDomainSocketFactory( + const std::string& path, + const UnixDomainSocket::AuthCallback& auth_callback) + : path_(path), + auth_callback_(auth_callback) {} + +UnixDomainSocketFactory::~UnixDomainSocketFactory() {} + +scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen( + StreamListenSocket::Delegate* delegate) const { + return UnixDomainSocket::CreateAndListen(path_, delegate, auth_callback_); +} + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) + +UnixDomainSocketWithAbstractNamespaceFactory:: +UnixDomainSocketWithAbstractNamespaceFactory( + const std::string& path, + const UnixDomainSocket::AuthCallback& auth_callback) + : UnixDomainSocketFactory(path, auth_callback) {} + +UnixDomainSocketWithAbstractNamespaceFactory:: +~UnixDomainSocketWithAbstractNamespaceFactory() {} + +scoped_refptr<StreamListenSocket> +UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen( + StreamListenSocket::Delegate* delegate) const { + return UnixDomainSocket::CreateAndListenWithAbstractNamespace( + path_, delegate, auth_callback_); +} + +#endif + +} // namespace net diff --git a/net/base/unix_domain_socket_posix.h b/net/base/unix_domain_socket_posix.h new file mode 100644 index 0000000..7730900 --- /dev/null +++ b/net/base/unix_domain_socket_posix.h @@ -0,0 +1,121 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_BASE_UNIX_DOMAIN_SOCKET_POSIX_H_ +#define NET_BASE_UNIX_DOMAIN_SOCKET_POSIX_H_ +#pragma once + +#include <string> + +#include "base/basictypes.h" +#include "base/callback_forward.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "build/build_config.h" +#include "net/base/net_export.h" +#include "net/base/stream_listen_socket.h" + +#if defined(OS_ANDROID) || defined(OS_LINUX) +// Feature only supported on Linux currently. This lets the Unix Domain Socket +// not be backed by the file system. +#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED +#endif + +namespace net { + +// Unix Domain Socket Implementation. Supports abstract namespaces on Linux. +class NET_EXPORT UnixDomainSocket : public StreamListenSocket { + public: + // Callback that returns whether the already connected client, identified by + // its process |user_id| and |group_id|, is allowed to keep the connection + // open. Note that the socket is closed immediately in case the callback + // returns false. + typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback; + + // Returns an authentication callback that always grants access for + // convenience in case you don't want to use authentication. + static AuthCallback NoAuthentication(); + + // Note that the returned UnixDomainSocket instance does not take ownership of + // |del|. + static scoped_refptr<UnixDomainSocket> CreateAndListen( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) + // Same as above except that the created socket uses the abstract namespace + // which is a Linux-only feature. + static scoped_refptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); +#endif + + private: + UnixDomainSocket(SocketDescriptor s, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); + virtual ~UnixDomainSocket(); + + static UnixDomainSocket* CreateAndListenInternal( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback, + bool use_abstract_namespace); + + static SocketDescriptor CreateAndBind(const std::string& path, + bool use_abstract_namespace); + + // StreamListenSocket: + virtual void Accept() OVERRIDE; + + AuthCallback auth_callback_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainSocket); +}; + +// Factory that can be used to instantiate UnixDomainSocket. +class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory { + public: + // Note that this class does not take ownership of the provided delegate. + UnixDomainSocketFactory(const std::string& path, + const UnixDomainSocket::AuthCallback& auth_callback); + virtual ~UnixDomainSocketFactory(); + + // StreamListenSocketFactory: + virtual scoped_refptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const OVERRIDE; + + protected: + const std::string path_; + const UnixDomainSocket::AuthCallback auth_callback_; + + private: + DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketFactory); +}; + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) +// Use this factory to instantiate UnixDomainSocket using the abstract +// namespace feature (only supported on Linux). +class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory + : public UnixDomainSocketFactory { + public: + UnixDomainSocketWithAbstractNamespaceFactory( + const std::string& path, + const UnixDomainSocket::AuthCallback& auth_callback); + virtual ~UnixDomainSocketWithAbstractNamespaceFactory(); + + // UnixDomainSocketFactory: + virtual scoped_refptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const OVERRIDE; + + private: + DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketWithAbstractNamespaceFactory); +}; +#endif + +} // namespace net + +#endif // NET_BASE_UNIX_DOMAIN_SOCKET_POSIX_H_ diff --git a/net/base/unix_domain_socket_posix_unittest.cc b/net/base/unix_domain_socket_posix_unittest.cc new file mode 100644 index 0000000..e72ac59 --- /dev/null +++ b/net/base/unix_domain_socket_posix_unittest.cc @@ -0,0 +1,313 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <errno.h> +#include <fcntl.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/time.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include <cstring> +#include <queue> +#include <string> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/eintr_wrapper.h" +#include "base/file_path.h" +#include "base/file_util.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop.h" +#include "base/synchronization/condition_variable.h" +#include "base/synchronization/lock.h" +#include "base/threading/platform_thread.h" +#include "base/threading/thread.h" +#include "net/base/unix_domain_socket_posix.h" +#include "testing/gtest/include/gtest/gtest.h" + +using std::queue; +using std::string; + +namespace net { +namespace { + +const char kSocketFilename[] = "unix_domain_socket_for_testing"; +const char kInvalidSocketPath[] = "/invalid/path"; +const char kMsg[] = "hello"; + +enum EventType { + EVENT_ACCEPT, + EVENT_AUTH_DENIED, + EVENT_AUTH_GRANTED, + EVENT_CLOSE, + EVENT_LISTEN, + EVENT_READ, +}; + +string MakeSocketPath() { + FilePath temp_dir; + file_util::GetTempDir(&temp_dir); + return temp_dir.Append(kSocketFilename).value(); +} + +class EventManager : public base::RefCounted<EventManager> { + public: + EventManager() : condition_(&mutex_) {} + + bool HasPendingEvent() { + base::AutoLock lock(mutex_); + return !events_.empty(); + } + + void Notify(EventType event) { + base::AutoLock lock(mutex_); + events_.push(event); + condition_.Broadcast(); + } + + EventType WaitForEvent() { + base::AutoLock lock(mutex_); + while (events_.empty()) + condition_.Wait(); + EventType event = events_.front(); + events_.pop(); + return event; + } + + private: + friend class base::RefCounted<EventManager>; + virtual ~EventManager() {} + + queue<EventType> events_; + base::Lock mutex_; + base::ConditionVariable condition_; +}; + +class TestListenSocketDelegate : public StreamListenSocket::Delegate { + public: + explicit TestListenSocketDelegate( + const scoped_refptr<EventManager>& event_manager) + : event_manager_(event_manager) {} + + virtual void DidAccept(StreamListenSocket* server, + StreamListenSocket* connection) OVERRIDE { + LOG(ERROR) << __PRETTY_FUNCTION__; + connection_ = connection; + Notify(EVENT_ACCEPT); + } + + virtual void DidRead(StreamListenSocket* connection, + const char* data, + int len) OVERRIDE { + { + base::AutoLock lock(mutex_); + DCHECK(len); + data_.assign(data, len - 1); + } + Notify(EVENT_READ); + } + + virtual void DidClose(StreamListenSocket* sock) OVERRIDE { + Notify(EVENT_CLOSE); + } + + void OnListenCompleted() { + Notify(EVENT_LISTEN); + } + + string ReceivedData() { + base::AutoLock lock(mutex_); + return data_; + } + + private: + void Notify(EventType event) { + event_manager_->Notify(event); + } + + const scoped_refptr<EventManager> event_manager_; + scoped_refptr<StreamListenSocket> connection_; + base::Lock mutex_; + string data_; +}; + +bool UserCanConnectCallback( + bool allow_user, const scoped_refptr<EventManager>& event_manager, + uid_t, gid_t) { + event_manager->Notify( + allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED); + return allow_user; +} + +class UnixDomainSocketTestHelper : public testing::Test { + public: + void CreateAndListen() { + socket_ = UnixDomainSocket::CreateAndListen( + file_path_.value(), socket_delegate_.get(), MakeAuthCallback()); + socket_delegate_->OnListenCompleted(); + } + + protected: + UnixDomainSocketTestHelper(const string& path, bool allow_user) + : file_path_(path), + allow_user_(allow_user) {} + + virtual void SetUp() OVERRIDE { + event_manager_ = new EventManager(); + socket_delegate_.reset(new TestListenSocketDelegate(event_manager_)); + DeleteSocketFile(); + } + + virtual void TearDown() OVERRIDE { + DeleteSocketFile(); + socket_ = NULL; + socket_delegate_.reset(); + event_manager_ = NULL; + } + + UnixDomainSocket::AuthCallback MakeAuthCallback() { + return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_); + } + + void DeleteSocketFile() { + ASSERT_FALSE(file_path_.empty()); + file_util::Delete(file_path_, false /* not recursive */); + } + + SocketDescriptor CreateClientSocket() { + const SocketDescriptor sock = socket(PF_UNIX, SOCK_STREAM, 0); + if (sock < 0) { + LOG(ERROR) << "socket() error"; + return StreamListenSocket::kInvalidSocket; + } + sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + socklen_t addr_len; + strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path)); + addr_len = sizeof(sockaddr_un); + if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) { + LOG(ERROR) << "connect() error"; + return StreamListenSocket::kInvalidSocket; + } + return sock; + } + + scoped_ptr<base::Thread> CreateAndRunServerThread() { + base::Thread::Options options; + options.message_loop_type = MessageLoop::TYPE_IO; + scoped_ptr<base::Thread> thread(new base::Thread("socketio_test")); + thread->StartWithOptions(options); + thread->message_loop()->PostTask( + FROM_HERE, + base::Bind(&UnixDomainSocketTestHelper::CreateAndListen, + base::Unretained(this))); + return thread.Pass(); + } + + const FilePath file_path_; + const bool allow_user_; + scoped_refptr<EventManager> event_manager_; + scoped_ptr<TestListenSocketDelegate> socket_delegate_; + scoped_refptr<UnixDomainSocket> socket_; +}; + +class UnixDomainSocketTest : public UnixDomainSocketTestHelper { + protected: + UnixDomainSocketTest() + : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {} +}; + +class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper { + protected: + UnixDomainSocketTestWithInvalidPath() + : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {} +}; + +class UnixDomainSocketTestWithForbiddenUser + : public UnixDomainSocketTestHelper { + protected: + UnixDomainSocketTestWithForbiddenUser() + : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {} +}; + +TEST_F(UnixDomainSocketTest, CreateAndListen) { + CreateAndListen(); + EXPECT_FALSE(socket_.get() == NULL); +} + +TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) { + CreateAndListen(); + EXPECT_TRUE(socket_.get() == NULL); +} + +#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED +// Test with an invalid path to make sure that the socket is not backed by a +// file. +TEST_F(UnixDomainSocketTestWithInvalidPath, + CreateAndListenWithAbstractNamespace) { + socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( + file_path_.value(), socket_delegate_.get(), MakeAuthCallback()); + EXPECT_FALSE(socket_.get() == NULL); +} +#endif + +TEST_F(UnixDomainSocketTest, TestWithClient) { + const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); + EventType event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_LISTEN, event); + + // Create the client socket. + const SocketDescriptor sock = CreateClientSocket(); + ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_AUTH_GRANTED, event); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_ACCEPT, event); + + // Send a message from the client to the server. + ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); + ASSERT_NE(-1, ret); + ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret)); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_READ, event); + ASSERT_EQ(kMsg, socket_delegate_->ReceivedData()); + + // Close the client socket. + ret = HANDLE_EINTR(close(sock)); + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_CLOSE, event); +} + +TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) { + const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); + EventType event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_LISTEN, event); + const SocketDescriptor sock = CreateClientSocket(); + ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + + event = event_manager_->WaitForEvent(); + ASSERT_EQ(EVENT_AUTH_DENIED, event); + + // Wait until the file descriptor is closed by the server. + struct pollfd poll_fd; + poll_fd.fd = sock; + poll_fd.events = POLLIN; + poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */); + + // Send() must fail. + ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); + ASSERT_EQ(-1, ret); + ASSERT_EQ(EPIPE, errno); + ASSERT_FALSE(event_manager_->HasPendingEvent()); +} + +} // namespace +} // namespace net diff --git a/net/net.gyp b/net/net.gyp index d1a3b46..a2d1d55 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -261,6 +261,8 @@ 'base/transport_security_state.cc', 'base/transport_security_state.h', 'base/transport_security_state_static.h', + 'base/unix_domain_socket_posix.cc', + 'base/unix_domain_socket_posix.h', 'base/upload_data.cc', 'base/upload_data.h', 'base/upload_data_stream.cc', @@ -1091,8 +1093,9 @@ 'base/test_certificate_data.h', 'base/test_completion_callback_unittest.cc', 'base/transport_security_state_unittest.cc', - 'base/upload_data_unittest.cc', + 'base/unix_domain_socket_posix_unittest.cc', 'base/upload_data_stream_unittest.cc', + 'base/upload_data_unittest.cc', 'base/x509_certificate_unittest.cc', 'base/x509_cert_types_unittest.cc', 'base/x509_util_nss_unittest.cc', diff --git a/net/tools/fetch/http_listen_socket.cc b/net/tools/fetch/http_listen_socket.cc index dbb34e6..ddad463 100644 --- a/net/tools/fetch/http_listen_socket.cc +++ b/net/tools/fetch/http_listen_socket.cc @@ -13,7 +13,7 @@ #include "net/tools/fetch/http_server_request_info.h" #include "net/tools/fetch/http_server_response_info.h" -HttpListenSocket::HttpListenSocket(SOCKET s, +HttpListenSocket::HttpListenSocket(SocketDescriptor s, HttpListenSocket::Delegate* delegate) : ALLOW_THIS_IN_INITIALIZER_LIST(net::TCPListenSocket(s, this)), delegate_(delegate) { @@ -23,7 +23,7 @@ HttpListenSocket::~HttpListenSocket() { } void HttpListenSocket::Accept() { - SOCKET conn = net::TCPListenSocket::AcceptSocket(); + SocketDescriptor conn = net::TCPListenSocket::AcceptSocket(); DCHECK_NE(conn, net::TCPListenSocket::kInvalidSocket); if (conn == net::TCPListenSocket::kInvalidSocket) { // TODO @@ -40,7 +40,7 @@ scoped_refptr<HttpListenSocket> HttpListenSocket::CreateAndListen( const std::string& ip, int port, HttpListenSocket::Delegate* delegate) { - SOCKET s = net::TCPListenSocket::CreateAndBind(ip, port); + SocketDescriptor s = net::TCPListenSocket::CreateAndBind(ip, port); if (s == net::TCPListenSocket::kInvalidSocket) { // TODO (ibrar): error handling. } else { diff --git a/net/tools/fetch/http_listen_socket.h b/net/tools/fetch/http_listen_socket.h index 4c3b27b..5eaf7a9 100644 --- a/net/tools/fetch/http_listen_socket.h +++ b/net/tools/fetch/http_listen_socket.h @@ -50,7 +50,7 @@ class HttpListenSocket : public net::TCPListenSocket, static const int kReadBufSize = 16 * 1024; // Must run in the IO thread. - HttpListenSocket(SOCKET s, HttpListenSocket::Delegate* del); + HttpListenSocket(SocketDescriptor s, HttpListenSocket::Delegate* del); virtual ~HttpListenSocket(); // Expects the raw data to be stored in recv_data_. If parsing is successful, |