diff options
author | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-03-13 01:54:38 +0000 |
---|---|---|
committer | alexeypa@chromium.org <alexeypa@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-03-13 01:54:38 +0000 |
commit | 0a29063eef78157e35da84f60c7c12043f40f29d (patch) | |
tree | c1d9969a4bce7580aa9b10f0166dd5fdc27fcedb /remoting | |
parent | e6510e78aa52fa6e92c47840d98c23000b31f2ca (diff) | |
download | chromium_src-0a29063eef78157e35da84f60c7c12043f40f29d.zip chromium_src-0a29063eef78157e35da84f60c7c12043f40f29d.tar.gz chromium_src-0a29063eef78157e35da84f60c7c12043f40f29d.tar.bz2 |
Moved the methods for mapping a session ID to the connected RDP client's address (and vice versa) to WtsTerminalMonitor.
This makes GetEndpointForSessionId() and GetSessionIdForEndpoint() reusable.
BUG=137696
TEST=remoting_unittests.RdpClientTest.*
Review URL: https://chromiumcodereview.appspot.com/12632011
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@187743 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'remoting')
-rw-r--r-- | remoting/host/win/host_service.cc | 136 | ||||
-rw-r--r-- | remoting/host/win/host_service.h | 23 | ||||
-rw-r--r-- | remoting/host/win/rdp_client_unittest.cc | 20 | ||||
-rw-r--r-- | remoting/host/win/wts_terminal_monitor.cc | 176 | ||||
-rw-r--r-- | remoting/host/win/wts_terminal_monitor.h | 26 | ||||
-rw-r--r-- | remoting/remoting.gyp | 6 |
6 files changed, 220 insertions, 167 deletions
diff --git a/remoting/host/win/host_service.cc b/remoting/host/win/host_service.cc index b760b0f..60a1796 100644 --- a/remoting/host/win/host_service.cc +++ b/remoting/host/win/host_service.cc @@ -18,7 +18,6 @@ #include "base/files/file_path.h" #include "base/message_loop.h" #include "base/run_loop.h" -#include "base/scoped_native_library.h" #include "base/single_thread_task_runner.h" #include "base/threading/thread.h" #include "base/utf_string_conversions.h" @@ -47,13 +46,6 @@ namespace remoting { namespace { -// Used to query the endpoint of an attached RDP client. -const WINSTATIONINFOCLASS kWinStationRemoteAddress = - static_cast<WINSTATIONINFOCLASS>(29); - -// Session id that does not represent any session. -const uint32 kInvalidSessionId = 0xffffffffu; - const char kIoThreadName[] = "I/O thread"; // A window class for the session change notifications window. @@ -213,7 +205,6 @@ void HostService::RemoveWtsTerminalObserver(WtsTerminalObserver* observer) { } HostService::HostService() : - win_station_query_information_(NULL), run_routine_(&HostService::RunAsService), service_status_handle_(0), stopped_event_(true, false) { @@ -222,133 +213,6 @@ HostService::HostService() : HostService::~HostService() { } -bool HostService::GetEndpointForSessionId(uint32 session_id, - net::IPEndPoint* endpoint) { - DCHECK(main_task_runner_->BelongsToCurrentThread()); - - // Fast path for the case when |session_id| is currently attached to - // the physical console. - if (session_id == WTSGetActiveConsoleSessionId()) { - *endpoint = net::IPEndPoint(); - return true; - } - - // Get the pointer to winsta!WinStationQueryInformationW(). - if (!LoadWinStationLibrary()) - return false; - - // WinStationRemoteAddress information class returns the following structure. - // Note that its layout is different from sockaddr_in/sockaddr_in6. For - // instance both |ipv4| and |ipv6| structures are 4 byte aligned so there is - // additional 2 byte padding after |sin_family|. - struct RemoteAddress { - unsigned short sin_family; - union { - struct { - USHORT sin_port; - ULONG in_addr; - UCHAR sin_zero[8]; - } ipv4; - struct { - USHORT sin6_port; - ULONG sin6_flowinfo; - USHORT sin6_addr[8]; - ULONG sin6_scope_id; - } ipv6; - }; - }; - - RemoteAddress address; - ULONG length; - if (!win_station_query_information_(WTS_CURRENT_SERVER_HANDLE, - session_id, - kWinStationRemoteAddress, - &address, - sizeof(address), - &length)) { - // WinStationQueryInformationW() fails if no RDP client is attached to - // |session_id|. - return false; - } - - // Convert the RemoteAddress structure into sockaddr_in/sockaddr_in6. - switch (address.sin_family) { - case AF_INET: { - sockaddr_in ipv4 = { 0 }; - ipv4.sin_family = AF_INET; - ipv4.sin_port = address.ipv4.sin_port; - ipv4.sin_addr.S_un.S_addr = address.ipv4.in_addr; - return endpoint->FromSockAddr( - reinterpret_cast<struct sockaddr*>(&ipv4), sizeof(ipv4)); - } - - case AF_INET6: { - sockaddr_in6 ipv6 = { 0 }; - ipv6.sin6_family = AF_INET6; - ipv6.sin6_port = address.ipv6.sin6_port; - ipv6.sin6_flowinfo = address.ipv6.sin6_flowinfo; - memcpy(&ipv6.sin6_addr, address.ipv6.sin6_addr, sizeof(ipv6.sin6_addr)); - ipv6.sin6_scope_id = address.ipv6.sin6_scope_id; - return endpoint->FromSockAddr( - reinterpret_cast<struct sockaddr*>(&ipv6), sizeof(ipv6)); - } - - default: - return false; - } -} - -uint32 HostService::GetSessionIdForEndpoint( - const net::IPEndPoint& client_endpoint) { - DCHECK(main_task_runner_->BelongsToCurrentThread()); - - // Use the fast path if the caller wants to get id of the session attached to - // the physical console. - if (client_endpoint == net::IPEndPoint()) - return WTSGetActiveConsoleSessionId(); - - // Get the pointer to winsta!WinStationQueryInformationW(). - if (!LoadWinStationLibrary()) - return kInvalidSessionId; - - // Enumerate all sessions and try to match the client endpoint. - WTS_SESSION_INFO* session_info; - DWORD session_info_count; - if (!WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE, 0, 1, &session_info, - &session_info_count)) { - LOG_GETLASTERROR(ERROR) << "Failed to enumerate all sessions"; - return kInvalidSessionId; - } - for (DWORD i = 0; i < session_info_count; ++i) { - net::IPEndPoint endpoint; - if (GetEndpointForSessionId(session_info[i].SessionId, &endpoint) && - endpoint == client_endpoint) { - WTSFreeMemory(session_info); - return session_info[i].SessionId; - } - } - - // |client_endpoint| is not associated with any session. - WTSFreeMemory(session_info); - return kInvalidSessionId; -} - -bool HostService::LoadWinStationLibrary() { - if (!winsta_) { - base::FilePath winsta_path(base::GetNativeLibraryName( - UTF8ToUTF16("winsta"))); - winsta_.reset(new base::ScopedNativeLibrary(winsta_path)); - - if (winsta_->is_valid()) { - win_station_query_information_ = - static_cast<PWINSTATIONQUERYINFORMATIONW>( - winsta_->GetFunctionPointer("WinStationQueryInformationW")); - } - } - - return win_station_query_information_ != NULL; -} - void HostService::OnSessionChange(uint32 event, uint32 session_id) { DCHECK(main_task_runner_->BelongsToCurrentThread()); DCHECK_NE(session_id, kInvalidSessionId); diff --git a/remoting/host/win/host_service.h b/remoting/host/win/host_service.h index 7c0b05d..942f3fa 100644 --- a/remoting/host/win/host_service.h +++ b/remoting/host/win/host_service.h @@ -6,7 +6,6 @@ #define REMOTING_HOST_WIN_HOST_SERVICE_H_ #include <windows.h> -#include <winternl.h> #include <list> @@ -19,7 +18,6 @@ class CommandLine; namespace base { -class ScopedNativeLibrary; class SingleThreadTaskRunner; } // namespace base @@ -49,21 +47,6 @@ class HostService : public WtsTerminalMonitor { HostService(); ~HostService(); - // Sets |*endpoint| to the endpoint of the client attached to |session_id|. - // If |session_id| is attached to the physical console net::IPEndPoint() is - // used. Returns false if the endpoint cannot be queried (if there is no - // client attached to |session_id| for instance). - bool GetEndpointForSessionId(uint32 session_id, net::IPEndPoint* endpoint); - - // Returns id of the session that |client_endpoint| is attached. - // |kInvalidSessionId| is returned if none of the sessions is currently - // attahced to |client_endpoint|. - uint32 GetSessionIdForEndpoint(const net::IPEndPoint& client_endpoint); - - // Gets the pointer to winsta!WinStationQueryInformationW(). Returns false if - // en error occurs. - bool LoadWinStationLibrary(); - // Notifies the service of changes in session state. void OnSessionChange(uint32 event, uint32 session_id); @@ -118,12 +101,6 @@ class HostService : public WtsTerminalMonitor { // The list of observers receiving session notifications. std::list<RegisteredObserver> observers_; - // Handle of dynamically loaded winsta.dll. - scoped_ptr<base::ScopedNativeLibrary> winsta_; - - // Points to winsta!WinStationQueryInformationW(). - PWINSTATIONQUERYINFORMATIONW win_station_query_information_; - scoped_ptr<Stoppable> child_; // Service message loop. diff --git a/remoting/host/win/rdp_client_unittest.cc b/remoting/host/win/rdp_client_unittest.cc index f906829..229bc06c 100644 --- a/remoting/host/win/rdp_client_unittest.cc +++ b/remoting/host/win/rdp_client_unittest.cc @@ -7,12 +7,15 @@ #include <atlhost.h> #include "base/basictypes.h" +#include "base/bind.h" +#include "base/bind_helpers.h" #include "base/message_loop.h" #include "base/run_loop.h" #include "base/win/scoped_com_initializer.h" #include "net/base/ip_endpoint.h" #include "remoting/base/auto_thread_task_runner.h" #include "remoting/host/win/rdp_client.h" +#include "remoting/host/win/wts_terminal_monitor.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gmock_mutant.h" #include "testing/gtest/include/gtest/gtest.h" @@ -79,6 +82,9 @@ class RdpClientTest : public testing::Test { virtual void SetUp() OVERRIDE; virtual void TearDown() OVERRIDE; + // Caaled when an RDP connection is established. + void OnRdpConnected(const net::IPEndPoint& endpoint); + // Tears down |rdp_client_|. void CloseRdpClient(); @@ -119,6 +125,18 @@ void RdpClientTest::TearDown() { module_.reset(); } +void RdpClientTest::OnRdpConnected(const net::IPEndPoint& endpoint) { + uint32 session_id = WtsTerminalMonitor::GetSessionIdForEndpoint(endpoint); + + net::IPEndPoint session_endpoint; + EXPECT_TRUE(WtsTerminalMonitor::GetEndpointForSessionId(session_id, + &session_endpoint)); + EXPECT_EQ(endpoint, session_endpoint); + + message_loop_.PostTask(FROM_HERE, base::Bind(&RdpClientTest::CloseRdpClient, + base::Unretained(this))); +} + void RdpClientTest::CloseRdpClient() { EXPECT_TRUE(rdp_client_); @@ -132,7 +150,7 @@ TEST_F(RdpClientTest, Basic) { // and a connection error as a successful outcome. EXPECT_CALL(event_handler_, OnRdpConnected(_)) .Times(AtMost(1)) - .WillOnce(InvokeWithoutArgs(this, &RdpClientTest::CloseRdpClient)); + .WillOnce(Invoke(this, &RdpClientTest::OnRdpConnected)); EXPECT_CALL(event_handler_, OnRdpClosed()) .Times(AtMost(1)) .WillOnce(InvokeWithoutArgs(this, &RdpClientTest::CloseRdpClient)); diff --git a/remoting/host/win/wts_terminal_monitor.cc b/remoting/host/win/wts_terminal_monitor.cc new file mode 100644 index 0000000..0395651 --- /dev/null +++ b/remoting/host/win/wts_terminal_monitor.cc @@ -0,0 +1,176 @@ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "remoting/host/win/wts_terminal_monitor.h" + +#include <windows.h> +#include <winternl.h> +#include <wtsapi32.h> + +#include "base/basictypes.h" +#include "base/files/file_path.h" +#include "base/lazy_instance.h" +#include "base/native_library.h" +#include "base/scoped_native_library.h" +#include "base/utf_string_conversions.h" +#include "net/base/ip_endpoint.h" + +namespace { + +// Used to query the endpoint of an attached RDP client. +const WINSTATIONINFOCLASS kWinStationRemoteAddress = + static_cast<WINSTATIONINFOCLASS>(29); + +// WinStationRemoteAddress information class returns the following structure. +// Note that its layout is different from sockaddr_in/sockaddr_in6. For +// instance both |ipv4| and |ipv6| structures are 4 byte aligned so there is +// additional 2 byte padding after |sin_family|. +struct RemoteAddress { + unsigned short sin_family; + union { + struct { + USHORT sin_port; + ULONG in_addr; + UCHAR sin_zero[8]; + } ipv4; + struct { + USHORT sin6_port; + ULONG sin6_flowinfo; + USHORT sin6_addr[8]; + ULONG sin6_scope_id; + } ipv6; + }; +}; + +// Loads winsta.dll dynamically and resolves the address of +// the winsta!WinStationQueryInformationW() function. +class WinstaLoader { + public: + WinstaLoader(); + ~WinstaLoader(); + + // Returns the address and port of the RDP client attached to |session_id|. + bool GetRemoteAddress(uint32 session_id, RemoteAddress* address); + + private: + // Handle of dynamically loaded winsta.dll. + base::ScopedNativeLibrary winsta_; + + // Points to winsta!WinStationQueryInformationW(). + PWINSTATIONQUERYINFORMATIONW win_station_query_information_; + + DISALLOW_COPY_AND_ASSIGN(WinstaLoader); +}; + +static base::LazyInstance<WinstaLoader> g_winsta = LAZY_INSTANCE_INITIALIZER; + +WinstaLoader::WinstaLoader() : + winsta_(base::FilePath(base::GetNativeLibraryName(UTF8ToUTF16("winsta")))) { + + // Resolve the function pointer. + win_station_query_information_ = + static_cast<PWINSTATIONQUERYINFORMATIONW>( + winsta_.GetFunctionPointer("WinStationQueryInformationW")); +} + +WinstaLoader::~WinstaLoader() { +} + +bool WinstaLoader::GetRemoteAddress(uint32 session_id, RemoteAddress* address) { + ULONG length; + return win_station_query_information_(WTS_CURRENT_SERVER_HANDLE, + session_id, + kWinStationRemoteAddress, + address, + sizeof(*address), + &length) != FALSE; +} + +} // namespace + +namespace remoting { + +// Session id that does not represent any session. +const uint32 kInvalidSessionId = 0xffffffffu; + +WtsTerminalMonitor::~WtsTerminalMonitor() { +} + +// static +bool WtsTerminalMonitor::GetEndpointForSessionId(uint32 session_id, + net::IPEndPoint* endpoint) { + // Fast path for the case when |session_id| is currently attached to + // the physical console. + if (session_id == WTSGetActiveConsoleSessionId()) { + *endpoint = net::IPEndPoint(); + return true; + } + + RemoteAddress address; + // WinStationQueryInformationW() fails if no RDP client is attached to + // |session_id|. + if (!g_winsta.Get().GetRemoteAddress(session_id, &address)) + return false; + + // Convert the RemoteAddress structure into sockaddr_in/sockaddr_in6. + switch (address.sin_family) { + case AF_INET: { + sockaddr_in ipv4 = { 0 }; + ipv4.sin_family = AF_INET; + ipv4.sin_port = address.ipv4.sin_port; + ipv4.sin_addr.S_un.S_addr = address.ipv4.in_addr; + return endpoint->FromSockAddr( + reinterpret_cast<struct sockaddr*>(&ipv4), sizeof(ipv4)); + } + + case AF_INET6: { + sockaddr_in6 ipv6 = { 0 }; + ipv6.sin6_family = AF_INET6; + ipv6.sin6_port = address.ipv6.sin6_port; + ipv6.sin6_flowinfo = address.ipv6.sin6_flowinfo; + memcpy(&ipv6.sin6_addr, address.ipv6.sin6_addr, sizeof(ipv6.sin6_addr)); + ipv6.sin6_scope_id = address.ipv6.sin6_scope_id; + return endpoint->FromSockAddr( + reinterpret_cast<struct sockaddr*>(&ipv6), sizeof(ipv6)); + } + + default: + return false; + } +} + +// static +uint32 WtsTerminalMonitor::GetSessionIdForEndpoint( + const net::IPEndPoint& client_endpoint) { + // Use the fast path if the caller wants to get id of the session attached to + // the physical console. + if (client_endpoint == net::IPEndPoint()) + return WTSGetActiveConsoleSessionId(); + + // Enumerate all sessions and try to match the client endpoint. + WTS_SESSION_INFO* session_info; + DWORD session_info_count; + if (!WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE, 0, 1, &session_info, + &session_info_count)) { + LOG_GETLASTERROR(ERROR) << "Failed to enumerate all sessions"; + return kInvalidSessionId; + } + for (DWORD i = 0; i < session_info_count; ++i) { + net::IPEndPoint endpoint; + if (GetEndpointForSessionId(session_info[i].SessionId, &endpoint) && + endpoint == client_endpoint) { + WTSFreeMemory(session_info); + return session_info[i].SessionId; + } + } + + // |client_endpoint| is not associated with any session. + WTSFreeMemory(session_info); + return kInvalidSessionId; +} + +WtsTerminalMonitor::WtsTerminalMonitor() { +} + +} // namespace remoting diff --git a/remoting/host/win/wts_terminal_monitor.h b/remoting/host/win/wts_terminal_monitor.h index f8f8b79..670fba5 100644 --- a/remoting/host/win/wts_terminal_monitor.h +++ b/remoting/host/win/wts_terminal_monitor.h @@ -5,18 +5,22 @@ #ifndef REMOTING_HOST_WIN_WTS_TERMINAL_MONITOR_H_ #define REMOTING_HOST_WIN_WTS_TERMINAL_MONITOR_H_ -#include <windows.h> - #include "base/basictypes.h" -#include "net/base/ip_endpoint.h" + +namespace net { +class IPEndPoint; +} // namespace net namespace remoting { class WtsTerminalObserver; +// Session id that does not represent any session. +extern const uint32 kInvalidSessionId; + class WtsTerminalMonitor { public: - virtual ~WtsTerminalMonitor() {} + virtual ~WtsTerminalMonitor(); // Registers an observer to receive notifications about a particular WTS // terminal. To speficy the physical console the caller should pass @@ -31,8 +35,20 @@ class WtsTerminalMonitor { // Unregisters a previously registered observer. virtual void RemoveWtsTerminalObserver(WtsTerminalObserver* observer) = 0; + // Sets |*endpoint| to the endpoint of the client attached to |session_id|. + // If |session_id| is attached to the physical console net::IPEndPoint() is + // used. Returns false if the endpoint cannot be queried (if there is no + // client attached to |session_id| for instance). + static bool GetEndpointForSessionId(uint32 session_id, + net::IPEndPoint* endpoint); + + // Returns id of the session that |client_endpoint| is attached. + // |kInvalidSessionId| is returned if none of the sessions is currently + // attahced to |client_endpoint|. + static uint32 GetSessionIdForEndpoint(const net::IPEndPoint& client_endpoint); + protected: - WtsTerminalMonitor() {} + WtsTerminalMonitor(); private: DISALLOW_COPY_AND_ASSIGN(WtsTerminalMonitor); diff --git a/remoting/remoting.gyp b/remoting/remoting.gyp index 1f26013..2948047 100644 --- a/remoting/remoting.gyp +++ b/remoting/remoting.gyp @@ -461,6 +461,9 @@ 'host/win/session_event_executor.h', 'host/win/window_station_and_desktop.cc', 'host/win/window_station_and_desktop.h', + 'host/win/wts_terminal_monitor.cc', + 'host/win/wts_terminal_monitor.h', + 'host/win/wts_terminal_observer.h', ], 'conditions': [ ['toolkit_uses_gtk==1', { @@ -1421,8 +1424,6 @@ 'host/win/wts_console_session_process_driver.h', 'host/win/wts_session_process_delegate.cc', 'host/win/wts_session_process_delegate.h', - 'host/win/wts_terminal_monitor.h', - 'host/win/wts_terminal_observer.h', 'host/worker_process_ipc_delegate.h', ], 'msvs_settings': { @@ -2599,6 +2600,7 @@ 'link_settings': { 'libraries': [ '-lrpcrt4.lib', + '-lwtsapi32.lib', ], }, }], |