diff options
-rw-r--r-- | chrome_frame/chrome_frame.gyp | 3 | ||||
-rw-r--r-- | chrome_frame/chrome_tab.cc | 33 | ||||
-rw-r--r-- | chrome_frame/module_utils.cc | 175 | ||||
-rw-r--r-- | chrome_frame/module_utils.h | 80 | ||||
-rw-r--r-- | chrome_frame/test/data/test_dlls/1/TestDll.dll | bin | 0 -> 12800 bytes | |||
-rw-r--r-- | chrome_frame/test/data/test_dlls/2/TestDll.dll | bin | 0 -> 12800 bytes | |||
-rw-r--r-- | chrome_frame/test/data/test_dlls/3/TestDll.dll | bin | 0 -> 12800 bytes | |||
-rw-r--r-- | chrome_frame/test/data/test_dlls/DummyCF/npchrome_frame.dll | bin | 0 -> 12800 bytes | |||
-rw-r--r-- | chrome_frame/test/data/test_dlls/README | 7 | ||||
-rw-r--r-- | chrome_frame/test/data/test_dlls/TestDllNoCF/TestDll.dll | bin | 0 -> 7680 bytes | |||
-rw-r--r-- | chrome_frame/test/module_utils_unittest.cc | 196 |
11 files changed, 489 insertions, 5 deletions
diff --git a/chrome_frame/chrome_frame.gyp b/chrome_frame/chrome_frame.gyp index 9d62c56..d6b0827 100644 --- a/chrome_frame/chrome_frame.gyp +++ b/chrome_frame/chrome_frame.gyp @@ -218,6 +218,7 @@ 'test/chrome_frame_automation_mock.h', 'test/http_server.cc', 'test/http_server.h', + 'test/module_utils_unittest.cc', 'test/proxy_factory_mock.cc', 'test/proxy_factory_mock.h', 'test/run_all_unittests.cc', @@ -644,6 +645,8 @@ 'http_negotiate.h', 'iids.cc', 'in_place_menu.h', + 'module_utils.cc', + 'module_utils.h', 'ole_document_impl.h', 'protocol_sink_wrap.cc', 'protocol_sink_wrap.h', diff --git a/chrome_frame/chrome_tab.cc b/chrome_frame/chrome_tab.cc index d8deec4..08e8a84 100644 --- a/chrome_frame/chrome_tab.cc +++ b/chrome_frame/chrome_tab.cc @@ -13,6 +13,7 @@ #include "base/command_line.h" #include "base/file_util.h" #include "base/file_version_info.h" +#include "base/lock.h" #include "base/logging.h" #include "base/logging_win.h" #include "base/path_service.h" @@ -30,6 +31,7 @@ #include "chrome_frame/chrome_frame_reporting.h" #include "chrome_frame/chrome_launcher.h" #include "chrome_frame/chrome_protocol.h" +#include "chrome_frame/module_utils.h" #include "chrome_frame/resource.h" #include "chrome_frame/utils.h" #include "googleurl/src/url_util.h" @@ -45,7 +47,6 @@ void InitGoogleUrl() { url_util::IsStandard(kDummyUrl, url_parse::MakeRange(0, arraysize(kDummyUrl))); } - } static const wchar_t kBhoRegistryPath[] = @@ -72,6 +73,11 @@ OBJECT_ENTRY_AUTO(__uuidof(ChromeActiveDocument), ChromeActiveDocument) OBJECT_ENTRY_AUTO(__uuidof(ChromeFrame), ChromeFrameActivex) OBJECT_ENTRY_AUTO(__uuidof(ChromeProtocol), ChromeProtocol) + +// See comments in DllGetClassObject. +DllRedirector g_dll_redirector; +Lock g_redirector_lock; + class ChromeTabModule : public AtlPerUserModule<CAtlDllModuleT<ChromeTabModule> > { public: @@ -316,12 +322,29 @@ STDAPI DllCanUnloadNow() { // Returns a class factory to create an object of the requested type STDAPI DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID* ppv) { - if (g_patch_helper.InitializeAndPatchProtocolsIfNeeded()) { - // We should only get here once. - UrlMkSetSessionOption(URLMON_OPTION_USERAGENT_REFRESH, NULL, 0, 0); + // On first call, we scan the loaded module list to see if an older version + // of Chrome Frame is already loaded. If it is, then we delegate all calls + // to DllGetClassObject to it. This is to avoid having instances of + // different versions of e.g. the BHO through an upgrade. It also prevents + // us from repeatedly patching. + LPFNGETCLASSOBJECT redir_ptr = NULL; + { + AutoLock lock(g_redirector_lock); + g_dll_redirector.EnsureInitialized(L"npchrome_frame.dll", + CLSID_ChromeActiveDocument); + redir_ptr = g_dll_redirector.get_dll_get_class_object_ptr(); } - return _AtlModule.DllGetClassObject(rclsid, riid, ppv); + if (redir_ptr) { + return redir_ptr(rclsid, riid, ppv); + } else { + if (g_patch_helper.InitializeAndPatchProtocolsIfNeeded()) { + // We should only get here once. + UrlMkSetSessionOption(URLMON_OPTION_USERAGENT_REFRESH, NULL, 0, 0); + } + + return _AtlModule.DllGetClassObject(rclsid, riid, ppv); + } } // DllRegisterServer - Adds entries to the system registry diff --git a/chrome_frame/module_utils.cc b/chrome_frame/module_utils.cc new file mode 100644 index 0000000..77cdd91 --- /dev/null +++ b/chrome_frame/module_utils.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "chrome_frame/module_utils.h" + +#include <atlbase.h> +#include <TlHelp32.h> + +#include "base/scoped_ptr.h" +#include "base/file_version_info.h" +#include "base/logging.h" +#include "base/scoped_handle.h" +#include "base/string_util.h" +#include "base/version.h" + +DllRedirector::DllRedirector() : dcgo_ptr_(NULL), initialized_(false), + module_handle_(NULL) {} + +DllRedirector::~DllRedirector() { + if (module_handle_) { + FreeLibrary(module_handle_); + module_handle_ = NULL; + } +} + +void DllRedirector::EnsureInitialized(const wchar_t* module_name, + REFCLSID clsid) { + if (!initialized_) { + initialized_ = true; + // Also sets module_handle_. + dcgo_ptr_ = GetDllGetClassObjectFromModuleName(module_name, clsid); + } +} + +LPFNGETCLASSOBJECT DllRedirector::get_dll_get_class_object_ptr() const { + DCHECK(initialized_); + return dcgo_ptr_; +} + +LPFNGETCLASSOBJECT DllRedirector::GetDllGetClassObjectFromModuleName( + const wchar_t* module_name, REFCLSID clsid) { + module_handle_ = NULL; + LPFNGETCLASSOBJECT proc_ptr = NULL; + HMODULE module_handle; + if (GetOldestNamedModuleHandle(module_name, clsid, &module_handle)) { + HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase); + if (module_handle != this_module) { + proc_ptr = GetDllGetClassObjectPtr(module_handle); + if (proc_ptr) { + // Stash away the module handle in module_handle_ so that it will be + // automatically closed when we get destroyed. GetDllGetClassObjectPtr + // above will have incremented the module's ref count. + module_handle_ = module_handle; + } + } else { + LOG(INFO) << "Module Scan: DllGetClassObject found in current module."; + } + } + + return proc_ptr; +} + +bool DllRedirector::GetOldestNamedModuleHandle(const std::wstring& module_name, + REFCLSID clsid, + HMODULE* oldest_module_handle) { + DCHECK(oldest_module_handle); + + ScopedHandle snapshot(CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, 0)); + if (snapshot == INVALID_HANDLE_VALUE) { + LOG(ERROR) << "Could not create module snapshot!"; + return false; + } + + bool success = false; + PathToHModuleMap map; + + // First get the list of module paths, and save the full path to base address + // mapping. + MODULEENTRY32W module_entry; + module_entry.dwSize = sizeof(module_entry); + BOOL cont = Module32FirstW(snapshot, &module_entry); + while (cont) { + if (!lstrcmpi(module_entry.szModule, module_name.c_str())) { + std::wstring full_path(module_entry.szExePath); + map[full_path] = module_entry.hModule; + } + cont = Module32NextW(snapshot, &module_entry); + } + + // Next, enumerate the map and find the oldest version of the module. + // (check if the map is of size 1 first) + if (!map.empty()) { + if (map.size() == 1) { + *oldest_module_handle = map.begin()->second; + } else { + *oldest_module_handle = GetHandleOfOldestModule(map, clsid); + } + + if (*oldest_module_handle != NULL) { + success = true; + } + } else { + LOG(INFO) << "Module Scan: No modules named " << module_name + << " were found."; + } + + return success; +} + +HMODULE DllRedirector::GetHandleOfOldestModule(const PathToHModuleMap& map, + REFCLSID clsid) { + HMODULE oldest_module = NULL; + scoped_ptr<Version> min_version( + Version::GetVersionFromString("999.999.999.999")); + + PathToHModuleMap::const_iterator map_iter(map.begin()); + for (; map_iter != map.end(); ++map_iter) { + // First check that either we are in the current module or that the DLL + // returns a class factory for our clsid. + bool current_module = + (map_iter->second == reinterpret_cast<HMODULE>(&__ImageBase)); + bool gco_succeeded = false; + if (!current_module) { + LPFNGETCLASSOBJECT dgco_ptr = GetDllGetClassObjectPtr(map_iter->second); + if (dgco_ptr) { + { + CComPtr<IClassFactory> class_factory; + HRESULT hr = dgco_ptr(clsid, IID_IClassFactory, + reinterpret_cast<void**>(&class_factory)); + gco_succeeded = SUCCEEDED(hr) && class_factory != NULL; + } + // Release the module ref count we picked up in GetDllGetClassObjectPtr. + FreeLibrary(map_iter->second); + } + } + + if (current_module || gco_succeeded) { + // Then check that the version is less than we've already found: + scoped_ptr<FileVersionInfo> version_info( + FileVersionInfo::CreateFileVersionInfo(map_iter->first)); + scoped_ptr<Version> version( + Version::GetVersionFromString(version_info->file_version())); + if (version->CompareTo(*min_version.get()) < 0) { + oldest_module = map_iter->second; + min_version.reset(version.release()); + } + } + } + + return oldest_module; +} + +LPFNGETCLASSOBJECT DllRedirector::GetDllGetClassObjectPtr(HMODULE module) { + LPFNGETCLASSOBJECT proc_ptr = NULL; + HMODULE temp_handle = 0; + // Increment the module ref count while we have an pointer to its + // DllGetClassObject function. + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, + reinterpret_cast<LPCTSTR>(module), + &temp_handle)) { + proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( + GetProcAddress(temp_handle, "DllGetClassObject")); + if (!proc_ptr) { + FreeLibrary(temp_handle); + LOG(ERROR) << "Module Scan: Couldn't get address of " + << "DllGetClassObject: " + << GetLastError(); + } + } else { + LOG(ERROR) << "Module Scan: Could not increment module count: " + << GetLastError(); + } + return proc_ptr; +} diff --git a/chrome_frame/module_utils.h b/chrome_frame/module_utils.h new file mode 100644 index 0000000..e5d8a68 --- /dev/null +++ b/chrome_frame/module_utils.h @@ -0,0 +1,80 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CHROME_FRAME_MODULE_UTILS_H_ +#define CHROME_FRAME_MODULE_UTILS_H_ + +#include <ObjBase.h> +#include <windows.h> + +#include <map> + +// A helper class that will find the named loaded module in the current +// process with the lowest version, increment its ref count and return +// a pointer to its DllGetClassObject() function if it exports one. If +// the oldest named module is the current module, then this class does nothing +// (does not muck with module ref count) and calls to +// get_dll_get_class_object_ptr() will return NULL. +class DllRedirector { + public: + typedef std::map<std::wstring, HMODULE> PathToHModuleMap; + + DllRedirector(); + ~DllRedirector(); + + // Must call this before calling get_dll_get_class_object_ptr(). On first call + // this performs the work of scanning the loaded modules for an old version + // to delegate to. Not thread safe. + void EnsureInitialized(const wchar_t* module_name, REFCLSID clsid); + + LPFNGETCLASSOBJECT get_dll_get_class_object_ptr() const; + + private: + + // Returns the pointer to the named loaded module's DllGetClassObject export + // or NULL if either the pointer could not be found or if the pointer would + // point into the current module. + // Sets module_handle_ and increments the modules reference count. + // + // For sanity's sake, the module must return a non-null class factory for + // the given class id. + LPFNGETCLASSOBJECT GetDllGetClassObjectFromModuleName( + const wchar_t* module_name, REFCLSID clsid); + + // Returns a handle in |module_handle| to the loaded module called + // |module_name| in the current process. If there are multiple modules with + // the same name, it returns the module with the oldest version number in its + // VERSIONINFO block. The version string is expected to be of a form that + // base::Version can parse. + // + // For sanity's sake, when there are multiple instances of the module, + // |product_short_name|, if non-NULL, must match the module's + // ProductShortName value + // + // Returns true if a named module with the given ProductShortName can be + // found, returns false otherwise. Can return the current module handle. + bool GetOldestNamedModuleHandle(const std::wstring& module_name, + REFCLSID clsid, + HMODULE* module_handle); + + // Given a PathToBaseAddressMap, iterates over the module images whose paths + // are the keys and returns the handle to the module with the lowest + // version number in its VERSIONINFO block whose DllGetClassObject returns a + // class factory for the given CLSID. + HMODULE GetHandleOfOldestModule(const PathToHModuleMap& map, REFCLSID clsid); + + private: + // Helper function to return the DllGetClassObject function pointer from + // the given module. On success, the return value is non-null and module + // will have had its reference count incremented. + LPFNGETCLASSOBJECT GetDllGetClassObjectPtr(HMODULE module); + + HMODULE module_handle_; + LPFNGETCLASSOBJECT dcgo_ptr_; + bool initialized_; + + friend class ModuleUtilsTest; +}; + +#endif // CHROME_FRAME_MODULE_UTILS_H_ diff --git a/chrome_frame/test/data/test_dlls/1/TestDll.dll b/chrome_frame/test/data/test_dlls/1/TestDll.dll Binary files differnew file mode 100644 index 0000000..9658de1 --- /dev/null +++ b/chrome_frame/test/data/test_dlls/1/TestDll.dll diff --git a/chrome_frame/test/data/test_dlls/2/TestDll.dll b/chrome_frame/test/data/test_dlls/2/TestDll.dll Binary files differnew file mode 100644 index 0000000..9501cac --- /dev/null +++ b/chrome_frame/test/data/test_dlls/2/TestDll.dll diff --git a/chrome_frame/test/data/test_dlls/3/TestDll.dll b/chrome_frame/test/data/test_dlls/3/TestDll.dll Binary files differnew file mode 100644 index 0000000..007ba36 --- /dev/null +++ b/chrome_frame/test/data/test_dlls/3/TestDll.dll diff --git a/chrome_frame/test/data/test_dlls/DummyCF/npchrome_frame.dll b/chrome_frame/test/data/test_dlls/DummyCF/npchrome_frame.dll Binary files differnew file mode 100644 index 0000000..a616fcc --- /dev/null +++ b/chrome_frame/test/data/test_dlls/DummyCF/npchrome_frame.dll diff --git a/chrome_frame/test/data/test_dlls/README b/chrome_frame/test/data/test_dlls/README new file mode 100644 index 0000000..5db353c --- /dev/null +++ b/chrome_frame/test/data/test_dlls/README @@ -0,0 +1,7 @@ +This directory contains dummy DLLs intended to support testing of the module
+scanning code currently in Chrome Frame (in module_utils.cc at time of writing).
+
+The DLLs contain no code of mention and export a DllGetClassObject function.
+The only main difference between them is that they have different version
+numbers in the FileVersion and ProductVersion fields of their
+VS_VERSION_INFO resources.
\ No newline at end of file diff --git a/chrome_frame/test/data/test_dlls/TestDllNoCF/TestDll.dll b/chrome_frame/test/data/test_dlls/TestDllNoCF/TestDll.dll Binary files differnew file mode 100644 index 0000000..cbdffcc --- /dev/null +++ b/chrome_frame/test/data/test_dlls/TestDllNoCF/TestDll.dll diff --git a/chrome_frame/test/module_utils_unittest.cc b/chrome_frame/test/module_utils_unittest.cc new file mode 100644 index 0000000..88d52f7 --- /dev/null +++ b/chrome_frame/test/module_utils_unittest.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This test requires loading a set of DLLs from the chrome_frame\test\data +// directory into the process and then inspecting them. As such, it is +// part of chrome_frame_tests.exe and not chrome_frame_unittests.exe which +// needs to run as a standalone test. No test is an island except for +// chrome_frame_unittests.exe. + +#include "chrome_frame/module_utils.h" +#include "testing/gtest/include/gtest/gtest.h" + +#include "base/logging.h" +#include "base/file_path.h" +#include "base/path_service.h" +#include "chrome_frame/test_utils.h" + +#include "chrome_tab.h" // NOLINT + +class ModuleUtilsTest : public testing::Test { + protected: + // Constructor + ModuleUtilsTest() {} + + // Returns the full path to the test DLL given a name. + virtual bool GetDllPath(const std::wstring& dll_name, std::wstring* path) { + if (!path) { + return false; + } + + FilePath test_path; + if (!PathService::Get(base::DIR_SOURCE_ROOT, &test_path)) { + return false; + } + + test_path = test_path.Append(L"chrome_frame") + .Append(L"test") + .Append(L"data") + .Append(L"test_dlls") + .Append(FilePath(dll_name)); + + *path = test_path.value(); + return true; + } + + // Loads the CF Dll and returns its path in |cf_dll_path|. + virtual bool LoadChromeFrameDll(std::wstring* cf_dll_path) { + DCHECK(cf_dll_path); + // Look for the CF dll in both the current directory and in servers. + FilePath dll_path = ScopedChromeFrameRegistrar::GetChromeFrameBuildPath(); + + bool success = false; + if (!dll_path.empty()) { + cf_dll_path_ = dll_path.value(); + HMODULE handle = LoadLibrary(cf_dll_path_.c_str()); + if (handle) { + hmodule_map_[cf_dll_path_] = handle; + *cf_dll_path = cf_dll_path_; + success = true; + } else { + LOG(ERROR) << "Failed to load test dll: " << dll_path.value(); + } + } + + return success; + } + + virtual bool LoadTestDll(const std::wstring& dll_name) { + bool success = false; + std::wstring dll_path; + if (GetDllPath(dll_name, &dll_path)) { + HMODULE handle = LoadLibrary(dll_path.c_str()); + if (handle) { + hmodule_map_[dll_name] = handle; + success = true; + } else { + LOG(ERROR) << "Failed to load test dll: " << dll_name; + } + } else { + LOG(ERROR) << "Failed to get dll path for " << dll_name; + } + return success; + } + + // Unload any DLLs we have loaded and make sure they stay unloaded. + virtual void TearDown() { + DllRedirector::PathToHModuleMap::const_iterator iter(hmodule_map_.begin()); + for (; iter != hmodule_map_.end(); ++iter) { + FreeLibrary(iter->second); + } + + // Check that the modules were actually unloaded (i.e. we had no dangling + // references). Do this after freeing all modules since they can have + // references to each other. + for (iter = hmodule_map_.begin(); iter != hmodule_map_.end(); ++iter) { + // The CF module gets pinned, so don't check that that is unloaded. + if (iter->first != cf_dll_path_) { + HMODULE temp_handle; + ASSERT_FALSE(GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, + reinterpret_cast<LPCTSTR>(iter->second), + &temp_handle)); + } + } + + hmodule_map_.clear(); + } + + DllRedirector::PathToHModuleMap hmodule_map_; + std::wstring cf_dll_path_; +}; + +// Tests that if we load a few versions of the same module that all export +// DllGetClassObject, that we correctly a) find a DllGetClassObject function +// pointer and b) find it in the right module. +TEST_F(ModuleUtilsTest, BasicTest) { + ASSERT_TRUE(LoadTestDll(L"3\\TestDll.dll")); + ASSERT_TRUE(LoadTestDll(L"2\\TestDll.dll")); + ASSERT_TRUE(LoadTestDll(L"1\\TestDll.dll")); + + DllRedirector redir; + redir.EnsureInitialized(L"TestDll.dll", CLSID_ChromeActiveDocument); + + LPFNGETCLASSOBJECT found_ptr = redir.get_dll_get_class_object_ptr(); + EXPECT_TRUE(found_ptr != NULL); + + LPFNGETCLASSOBJECT direct_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( + GetProcAddress(hmodule_map_[L"1\\TestDll.dll"], + "DllGetClassObject")); + EXPECT_TRUE(direct_ptr != NULL); + + EXPECT_EQ(found_ptr, direct_ptr); +} + +// Tests that a DLL that does not return a class factory for a Chrome Frame +// guid even though it has a lower version string. +TEST_F(ModuleUtilsTest, NoCFDllTest) { + ASSERT_TRUE(LoadTestDll(L"1\\TestDll.dll")); + ASSERT_TRUE(LoadTestDll(L"TestDllNoCF\\TestDll.dll")); + + DllRedirector redir; + redir.EnsureInitialized(L"TestDll.dll", CLSID_ChromeActiveDocument); + + LPFNGETCLASSOBJECT found_ptr = redir.get_dll_get_class_object_ptr(); + EXPECT_TRUE(found_ptr != NULL); + + LPFNGETCLASSOBJECT direct_ptr = + reinterpret_cast<LPFNGETCLASSOBJECT>( + GetProcAddress(hmodule_map_[L"1\\TestDll.dll"], + "DllGetClassObject")); + EXPECT_TRUE(direct_ptr != NULL); + + EXPECT_EQ(found_ptr, direct_ptr); +} + +// Tests that this works with the actual CF dll. +TEST_F(ModuleUtilsTest, ChromeFrameDllTest) { + ASSERT_TRUE(LoadTestDll(L"DummyCF\\npchrome_frame.dll")); + std::wstring cf_dll_path; + ASSERT_TRUE(LoadChromeFrameDll(&cf_dll_path)); + ASSERT_TRUE(!cf_dll_path.empty()); + + DllRedirector redir; + redir.EnsureInitialized(L"npchrome_frame.dll", CLSID_ChromeActiveDocument); + + LPFNGETCLASSOBJECT found_ptr = redir.get_dll_get_class_object_ptr(); + EXPECT_TRUE(found_ptr != NULL); + + LPFNGETCLASSOBJECT direct_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( + GetProcAddress(hmodule_map_[L"DummyCF\\npchrome_frame.dll"], + "DllGetClassObject")); + EXPECT_TRUE(direct_ptr != NULL); + + EXPECT_EQ(found_ptr, direct_ptr); + + // Now try asking for a ChromeActiveDocument using the non-dummy CF DLL + // handle and make sure that the delegation to the dummy module happens + // correctly. Use the bare guid to keep dependencies simple + const wchar_t kClsidChromeActiveDocument[] = + L"{3e1d0e7f-f5e3-44cc-aa6a-c0a637619ab8}"; + + LPFNGETCLASSOBJECT cf_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( + GetProcAddress(hmodule_map_[cf_dll_path], + "DllGetClassObject")); + EXPECT_TRUE(cf_ptr != NULL); + + CLSID cf_clsid; + HRESULT hr = CLSIDFromString(kClsidChromeActiveDocument, &cf_clsid); + EXPECT_HRESULT_SUCCEEDED(hr); + + CComPtr<IClassFactory> class_factory; + DWORD result = cf_ptr(cf_clsid, IID_IClassFactory, + reinterpret_cast<void**>(&class_factory)); + + EXPECT_EQ(S_OK, result); +} |