summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--chrome_frame/bho.cc9
-rw-r--r--chrome_frame/protocol_sink_wrap.cc104
-rw-r--r--chrome_frame/protocol_sink_wrap.h26
-rw-r--r--chrome_frame/vtable_patch_manager.cc11
4 files changed, 120 insertions, 30 deletions
diff --git a/chrome_frame/bho.cc b/chrome_frame/bho.cc
index e8c0374..9561cc1 100644
--- a/chrome_frame/bho.cc
+++ b/chrome_frame/bho.cc
@@ -18,7 +18,6 @@
#include "chrome_frame/utils.h"
#include "chrome_frame/vtable_patch_manager.h"
-const wchar_t kUrlMonDllName[] = L"urlmon.dll";
const wchar_t kPatchProtocols[] = L"PatchProtocols";
static const int kIBrowserServiceOnHttpEquivIndex = 30;
@@ -217,8 +216,7 @@ void PatchHelper::InitializeAndPatchProtocolsIfNeeded() {
bool patch_protocol = GetConfigBool(true, kPatchProtocols);
if (patch_protocol) {
- ProtocolSinkWrap::PatchProtocolHandler(kUrlMonDllName, CLSID_HttpProtocol);
- ProtocolSinkWrap::PatchProtocolHandler(kUrlMonDllName, CLSID_HttpSProtocol);
+ ProtocolSinkWrap::PatchProtocolHandlers();
state_ = PATCH_PROTOCOL;
} else {
state_ = PATCH_IBROWSER;
@@ -232,12 +230,9 @@ void PatchHelper::PatchBrowserService(IBrowserService* browser_service) {
IBrowserService_PatchInfo);
}
-extern vtable_patch::MethodPatchInfo IInternetProtocol_PatchInfo[];
-extern vtable_patch::MethodPatchInfo IInternetProtocolEx_PatchInfo[];
void PatchHelper::UnpatchIfNeeded() {
if (state_ == PATCH_PROTOCOL) {
- vtable_patch::UnpatchInterfaceMethods(IInternetProtocol_PatchInfo);
- vtable_patch::UnpatchInterfaceMethods(IInternetProtocolEx_PatchInfo);
+ ProtocolSinkWrap::UnpatchProtocolHandlers();
} else if (state_ == PATCH_IBROWSER_OK) {
vtable_patch::UnpatchInterfaceMethods(IBrowserService_PatchInfo);
}
diff --git a/chrome_frame/protocol_sink_wrap.cc b/chrome_frame/protocol_sink_wrap.cc
index a567e9f..5559e61 100644
--- a/chrome_frame/protocol_sink_wrap.cc
+++ b/chrome_frame/protocol_sink_wrap.cc
@@ -14,9 +14,8 @@
#include "base/string_util.h"
#include "chrome_frame/utils.h"
-#include "chrome_frame/vtable_patch_manager.h"
-// BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so
+// BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so
// not in everyone's headers yet. See:
// http://msdn.microsoft.com/en-us/library/ms775133(VS.85,loband).aspx
#ifndef BINDSTATUS_SERVER_MIMETYPEAVAILABLE
@@ -25,22 +24,36 @@
static const wchar_t* kChromeMimeType = L"application/chromepage";
static const char kTextHtmlMimeType[] = "text/html";
+const wchar_t kUrlMonDllName[] = L"urlmon.dll";
static const int kInternetProtocolStartIndex = 3;
static const int kInternetProtocolReadIndex = 9;
static const int kInternetProtocolStartExIndex = 13;
+// TODO(ananta)
+// We should avoid duplicate VTable declarations.
BEGIN_VTABLE_PATCHES(IInternetProtocol)
VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart)
VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead)
END_VTABLE_PATCHES()
+BEGIN_VTABLE_PATCHES(IInternetProtocolSecure)
+ VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart)
+ VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead)
+END_VTABLE_PATCHES()
+
BEGIN_VTABLE_PATCHES(IInternetProtocolEx)
VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart)
VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead)
VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx)
END_VTABLE_PATCHES()
+BEGIN_VTABLE_PATCHES(IInternetProtocolExSecure)
+ VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart)
+ VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead)
+ VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx)
+END_VTABLE_PATCHES()
+
//
// ProtocolSinkWrap implementation
//
@@ -64,12 +77,45 @@ ProtocolSinkWrap::~ProtocolSinkWrap() {
DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size();
}
-bool ProtocolSinkWrap::PatchProtocolHandler(const wchar_t* dll,
- const CLSID& handler_clsid) {
- HMODULE module = ::GetModuleHandle(dll);
+bool ProtocolSinkWrap::PatchProtocolHandlers() {
+ HRESULT hr = PatchProtocolMethods(CLSID_HttpProtocol,
+ IInternetProtocol_PatchInfo,
+ IInternetProtocolEx_PatchInfo);
+ if (FAILED(hr)) {
+ NOTREACHED() << "Failed to patch IInternetProtocol interface."
+ << " Error: " << hr;
+ return false;
+ }
+
+ hr = PatchProtocolMethods(CLSID_HttpSProtocol,
+ IInternetProtocolSecure_PatchInfo,
+ IInternetProtocolExSecure_PatchInfo);
+ if (FAILED(hr)) {
+ NOTREACHED() << "Failed to patch IInternetProtocol secure interface."
+ << " Error: " << hr;
+ return false;
+ }
+
+ return true;
+}
+
+void ProtocolSinkWrap::UnpatchProtocolHandlers() {
+ vtable_patch::UnpatchInterfaceMethods(IInternetProtocol_PatchInfo);
+ vtable_patch::UnpatchInterfaceMethods(IInternetProtocolEx_PatchInfo);
+ vtable_patch::UnpatchInterfaceMethods(IInternetProtocolSecure_PatchInfo);
+ vtable_patch::UnpatchInterfaceMethods(IInternetProtocolExSecure_PatchInfo);
+}
+
+HRESULT ProtocolSinkWrap::CreateProtocolHandlerInstance(
+ const CLSID& clsid, IInternetProtocol** protocol) {
+ if (!protocol) {
+ return E_INVALIDARG;
+ }
+
+ HMODULE module = ::GetModuleHandle(kUrlMonDllName);
if (!module) {
NOTREACHED() << "urlmon is not yet loaded. Error: " << GetLastError();
- return false;
+ return E_FAIL;
}
typedef HRESULT (WINAPI* DllGetClassObject_Fn)(REFCLSID, REFIID, LPVOID*);
@@ -77,15 +123,15 @@ bool ProtocolSinkWrap::PatchProtocolHandler(const wchar_t* dll,
::GetProcAddress(module, "DllGetClassObject"));
if (!fn) {
NOTREACHED() << "DllGetClassObject not found in urlmon.dll";
- return false;
+ return E_FAIL;
}
ScopedComPtr<IClassFactory> protocol_class_factory;
- HRESULT hr = fn(handler_clsid, IID_IClassFactory,
+ HRESULT hr = fn(clsid, IID_IClassFactory,
reinterpret_cast<LPVOID*>(protocol_class_factory.Receive()));
if (FAILED(hr)) {
NOTREACHED() << "DllGetclassObject failed. Error: " << hr;
- return false;
+ return hr;
}
ScopedComPtr<IInternetProtocol> handler_instance;
@@ -94,19 +140,39 @@ bool ProtocolSinkWrap::PatchProtocolHandler(const wchar_t* dll,
if (FAILED(hr)) {
NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol."
<< " Error: " << hr;
+ } else {
+ *protocol = handler_instance.Detach();
+ }
+
+ return hr;
+}
+
+HRESULT ProtocolSinkWrap::PatchProtocolMethods(
+ const CLSID& clsid_protocol,
+ vtable_patch::MethodPatchInfo* protocol_patch_info,
+ vtable_patch::MethodPatchInfo* protocol_ex_patch_info) {
+ if (!protocol_patch_info || !protocol_ex_patch_info) {
+ return E_INVALIDARG;
+ }
+
+ ScopedComPtr<IInternetProtocol> http_protocol;
+ HRESULT hr = CreateProtocolHandlerInstance(clsid_protocol,
+ http_protocol.Receive());
+ if (FAILED(hr)) {
+ NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol."
+ << " Error: " << hr;
return false;
}
ScopedComPtr<IInternetProtocolEx> ipex;
- ipex.QueryFrom(handler_instance);
+ ipex.QueryFrom(http_protocol);
if (ipex) {
- vtable_patch::PatchInterfaceMethods(ipex, IInternetProtocolEx_PatchInfo);
+ hr = vtable_patch::PatchInterfaceMethods(ipex, protocol_ex_patch_info);
} else {
- vtable_patch::PatchInterfaceMethods(handler_instance,
- IInternetProtocol_PatchInfo);
+ hr = vtable_patch::PatchInterfaceMethods(http_protocol,
+ protocol_patch_info);
}
-
- return true;
+ return hr;
}
// IInternetProtocol/Ex method implementation.
@@ -264,10 +330,10 @@ STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress,
// determination logic.
// NOTE: We use the report_data_recursiveness_ variable to detect situations
- // in which calls to ReportData are re-entrant (such as when the entire
- // contents of a page fit inside a single packet). In these cases, we
- // don't care about re-entrant calls beyond the second, and so we compare
- // report_data_recursiveness_ inside the while loop, making sure we skip
+ // in which calls to ReportData are re-entrant (such as when the entire
+ // contents of a page fit inside a single packet). In these cases, we
+ // don't care about re-entrant calls beyond the second, and so we compare
+ // report_data_recursiveness_ inside the while loop, making sure we skip
// what would otherwise be spurious calls to ReportProgress().
report_data_recursiveness_++;
diff --git a/chrome_frame/protocol_sink_wrap.h b/chrome_frame/protocol_sink_wrap.h
index e4f9cfb..bab018b 100644
--- a/chrome_frame/protocol_sink_wrap.h
+++ b/chrome_frame/protocol_sink_wrap.h
@@ -17,6 +17,7 @@
#include "base/scoped_comptr_win.h"
#include "googleurl/src/gurl.h"
#include "chrome_frame/ie8_types.h"
+#include "chrome_frame/vtable_patch_manager.h"
// Typedefs for IInternetProtocol and related methods that we patch.
typedef HRESULT (STDMETHODCALLTYPE* InternetProtocol_Start_Fn)(
@@ -85,8 +86,12 @@ END_COM_MAP()
bool Initialize(IInternetProtocol* protocol,
IInternetProtocolSink* original_sink, const wchar_t* url);
- static bool PatchProtocolHandler(const wchar_t* dll,
- const CLSID& handler_clsid);
+ // VTable patches the IInternetProtocol and IIntenetProtocolEx interface.
+ // Returns true on success.
+ static bool PatchProtocolHandlers();
+
+ // Unpatches the IInternetProtocol and IInternetProtocolEx interfaces.
+ static void UnpatchProtocolHandlers();
// IInternetProtocol/Ex patches.
static HRESULT STDMETHODCALLTYPE OnStart(InternetProtocol_Start_Fn orig_start,
@@ -183,6 +188,21 @@ END_COM_MAP()
return renderer_type_;
}
+ // Creates an instance of the specified protocol handler and returns the
+ // IInternetProtocol interface pointer.
+ // Returns S_OK on success.
+ static HRESULT CreateProtocolHandlerInstance(const CLSID& clsid,
+ IInternetProtocol** protocol);
+
+ // Helper function for patching the VTable of the IInternetProtocol
+ // interface. It instantiates the object identified by the protocol_clsid
+ // parameter and patches its VTable.
+ // Returns S_OK on success.
+ static HRESULT PatchProtocolMethods(
+ const CLSID& protocol_clsid,
+ vtable_patch::MethodPatchInfo* protocol_patch_info,
+ vtable_patch::MethodPatchInfo* protocol_ex_patch_info);
+
// WARNING: Don't use GURL variables here. Please see
// http://b/issue?id=2102171 for details.
@@ -202,7 +222,7 @@ END_COM_MAP()
HRESULT result_code_;
DWORD result_error_;
std::wstring result_text_;
- // For tracking re-entrency and preventing duplicate Read()s from
+ // For tracking re-entrency and preventing duplicate Read()s from
// distorting the outcome of ReportData.
int report_data_recursiveness_;
diff --git a/chrome_frame/vtable_patch_manager.cc b/chrome_frame/vtable_patch_manager.cc
index 5f15158..0b9f79aa 100644
--- a/chrome_frame/vtable_patch_manager.cc
+++ b/chrome_frame/vtable_patch_manager.cc
@@ -30,6 +30,15 @@ HRESULT PatchInterfaceMethods(void* unknown, MethodPatchInfo* patches) {
DCHECK(vtable);
for (MethodPatchInfo* it = patches; it->index_ != -1; ++it) {
+ if (it->stub_ != NULL) {
+ // If this DCHECK fires it means that we are using the same VTable
+ // information to patch two different interfaces.
+ DCHECK(false);
+ DLOG(ERROR) << "Attempting to patch two different VTables with the "
+ << "same VTable information";
+ continue;
+ }
+
PROC original_fn = vtable[it->index_];
FunctionStub* stub = FunctionStub::FromCode(original_fn);
if (stub != NULL) {
@@ -65,7 +74,7 @@ HRESULT PatchInterfaceMethods(void* unknown, MethodPatchInfo* patches) {
HRESULT UnpatchInterfaceMethods(MethodPatchInfo* patches) {
for (MethodPatchInfo* it = patches; it->index_ != -1; ++it) {
if (it->stub_) {
- DCHECK(it->stub_->absolute_target() ==
+ DCHECK(it->stub_->absolute_target() ==
reinterpret_cast<uintptr_t>(it->method_));
// Modify the stub to just jump directly to the original function.
it->stub_->BypassStub(reinterpret_cast<void*>(it->stub_->argument()));