summaryrefslogtreecommitdiffstats
path: root/chrome_frame/module_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'chrome_frame/module_utils.cc')
-rw-r--r--chrome_frame/module_utils.cc175
1 files changed, 175 insertions, 0 deletions
diff --git a/chrome_frame/module_utils.cc b/chrome_frame/module_utils.cc
new file mode 100644
index 0000000..77cdd91
--- /dev/null
+++ b/chrome_frame/module_utils.cc
@@ -0,0 +1,175 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "chrome_frame/module_utils.h"
+
+#include <atlbase.h>
+#include <TlHelp32.h>
+
+#include "base/scoped_ptr.h"
+#include "base/file_version_info.h"
+#include "base/logging.h"
+#include "base/scoped_handle.h"
+#include "base/string_util.h"
+#include "base/version.h"
+
+DllRedirector::DllRedirector() : dcgo_ptr_(NULL), initialized_(false),
+ module_handle_(NULL) {}
+
+DllRedirector::~DllRedirector() {
+ if (module_handle_) {
+ FreeLibrary(module_handle_);
+ module_handle_ = NULL;
+ }
+}
+
+void DllRedirector::EnsureInitialized(const wchar_t* module_name,
+ REFCLSID clsid) {
+ if (!initialized_) {
+ initialized_ = true;
+ // Also sets module_handle_.
+ dcgo_ptr_ = GetDllGetClassObjectFromModuleName(module_name, clsid);
+ }
+}
+
+LPFNGETCLASSOBJECT DllRedirector::get_dll_get_class_object_ptr() const {
+ DCHECK(initialized_);
+ return dcgo_ptr_;
+}
+
+LPFNGETCLASSOBJECT DllRedirector::GetDllGetClassObjectFromModuleName(
+ const wchar_t* module_name, REFCLSID clsid) {
+ module_handle_ = NULL;
+ LPFNGETCLASSOBJECT proc_ptr = NULL;
+ HMODULE module_handle;
+ if (GetOldestNamedModuleHandle(module_name, clsid, &module_handle)) {
+ HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase);
+ if (module_handle != this_module) {
+ proc_ptr = GetDllGetClassObjectPtr(module_handle);
+ if (proc_ptr) {
+ // Stash away the module handle in module_handle_ so that it will be
+ // automatically closed when we get destroyed. GetDllGetClassObjectPtr
+ // above will have incremented the module's ref count.
+ module_handle_ = module_handle;
+ }
+ } else {
+ LOG(INFO) << "Module Scan: DllGetClassObject found in current module.";
+ }
+ }
+
+ return proc_ptr;
+}
+
+bool DllRedirector::GetOldestNamedModuleHandle(const std::wstring& module_name,
+ REFCLSID clsid,
+ HMODULE* oldest_module_handle) {
+ DCHECK(oldest_module_handle);
+
+ ScopedHandle snapshot(CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, 0));
+ if (snapshot == INVALID_HANDLE_VALUE) {
+ LOG(ERROR) << "Could not create module snapshot!";
+ return false;
+ }
+
+ bool success = false;
+ PathToHModuleMap map;
+
+ // First get the list of module paths, and save the full path to base address
+ // mapping.
+ MODULEENTRY32W module_entry;
+ module_entry.dwSize = sizeof(module_entry);
+ BOOL cont = Module32FirstW(snapshot, &module_entry);
+ while (cont) {
+ if (!lstrcmpi(module_entry.szModule, module_name.c_str())) {
+ std::wstring full_path(module_entry.szExePath);
+ map[full_path] = module_entry.hModule;
+ }
+ cont = Module32NextW(snapshot, &module_entry);
+ }
+
+ // Next, enumerate the map and find the oldest version of the module.
+ // (check if the map is of size 1 first)
+ if (!map.empty()) {
+ if (map.size() == 1) {
+ *oldest_module_handle = map.begin()->second;
+ } else {
+ *oldest_module_handle = GetHandleOfOldestModule(map, clsid);
+ }
+
+ if (*oldest_module_handle != NULL) {
+ success = true;
+ }
+ } else {
+ LOG(INFO) << "Module Scan: No modules named " << module_name
+ << " were found.";
+ }
+
+ return success;
+}
+
+HMODULE DllRedirector::GetHandleOfOldestModule(const PathToHModuleMap& map,
+ REFCLSID clsid) {
+ HMODULE oldest_module = NULL;
+ scoped_ptr<Version> min_version(
+ Version::GetVersionFromString("999.999.999.999"));
+
+ PathToHModuleMap::const_iterator map_iter(map.begin());
+ for (; map_iter != map.end(); ++map_iter) {
+ // First check that either we are in the current module or that the DLL
+ // returns a class factory for our clsid.
+ bool current_module =
+ (map_iter->second == reinterpret_cast<HMODULE>(&__ImageBase));
+ bool gco_succeeded = false;
+ if (!current_module) {
+ LPFNGETCLASSOBJECT dgco_ptr = GetDllGetClassObjectPtr(map_iter->second);
+ if (dgco_ptr) {
+ {
+ CComPtr<IClassFactory> class_factory;
+ HRESULT hr = dgco_ptr(clsid, IID_IClassFactory,
+ reinterpret_cast<void**>(&class_factory));
+ gco_succeeded = SUCCEEDED(hr) && class_factory != NULL;
+ }
+ // Release the module ref count we picked up in GetDllGetClassObjectPtr.
+ FreeLibrary(map_iter->second);
+ }
+ }
+
+ if (current_module || gco_succeeded) {
+ // Then check that the version is less than we've already found:
+ scoped_ptr<FileVersionInfo> version_info(
+ FileVersionInfo::CreateFileVersionInfo(map_iter->first));
+ scoped_ptr<Version> version(
+ Version::GetVersionFromString(version_info->file_version()));
+ if (version->CompareTo(*min_version.get()) < 0) {
+ oldest_module = map_iter->second;
+ min_version.reset(version.release());
+ }
+ }
+ }
+
+ return oldest_module;
+}
+
+LPFNGETCLASSOBJECT DllRedirector::GetDllGetClassObjectPtr(HMODULE module) {
+ LPFNGETCLASSOBJECT proc_ptr = NULL;
+ HMODULE temp_handle = 0;
+ // Increment the module ref count while we have an pointer to its
+ // DllGetClassObject function.
+ if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
+ reinterpret_cast<LPCTSTR>(module),
+ &temp_handle)) {
+ proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>(
+ GetProcAddress(temp_handle, "DllGetClassObject"));
+ if (!proc_ptr) {
+ FreeLibrary(temp_handle);
+ LOG(ERROR) << "Module Scan: Couldn't get address of "
+ << "DllGetClassObject: "
+ << GetLastError();
+ }
+ } else {
+ LOG(ERROR) << "Module Scan: Could not increment module count: "
+ << GetLastError();
+ }
+ return proc_ptr;
+}