diff options
-rw-r--r-- | chrome_frame/bho.cc | 9 | ||||
-rw-r--r-- | chrome_frame/protocol_sink_wrap.cc | 104 | ||||
-rw-r--r-- | chrome_frame/protocol_sink_wrap.h | 26 | ||||
-rw-r--r-- | chrome_frame/vtable_patch_manager.cc | 11 |
4 files changed, 120 insertions, 30 deletions
diff --git a/chrome_frame/bho.cc b/chrome_frame/bho.cc index e8c0374..9561cc1 100644 --- a/chrome_frame/bho.cc +++ b/chrome_frame/bho.cc @@ -18,7 +18,6 @@ #include "chrome_frame/utils.h" #include "chrome_frame/vtable_patch_manager.h" -const wchar_t kUrlMonDllName[] = L"urlmon.dll"; const wchar_t kPatchProtocols[] = L"PatchProtocols"; static const int kIBrowserServiceOnHttpEquivIndex = 30; @@ -217,8 +216,7 @@ void PatchHelper::InitializeAndPatchProtocolsIfNeeded() { bool patch_protocol = GetConfigBool(true, kPatchProtocols); if (patch_protocol) { - ProtocolSinkWrap::PatchProtocolHandler(kUrlMonDllName, CLSID_HttpProtocol); - ProtocolSinkWrap::PatchProtocolHandler(kUrlMonDllName, CLSID_HttpSProtocol); + ProtocolSinkWrap::PatchProtocolHandlers(); state_ = PATCH_PROTOCOL; } else { state_ = PATCH_IBROWSER; @@ -232,12 +230,9 @@ void PatchHelper::PatchBrowserService(IBrowserService* browser_service) { IBrowserService_PatchInfo); } -extern vtable_patch::MethodPatchInfo IInternetProtocol_PatchInfo[]; -extern vtable_patch::MethodPatchInfo IInternetProtocolEx_PatchInfo[]; void PatchHelper::UnpatchIfNeeded() { if (state_ == PATCH_PROTOCOL) { - vtable_patch::UnpatchInterfaceMethods(IInternetProtocol_PatchInfo); - vtable_patch::UnpatchInterfaceMethods(IInternetProtocolEx_PatchInfo); + ProtocolSinkWrap::UnpatchProtocolHandlers(); } else if (state_ == PATCH_IBROWSER_OK) { vtable_patch::UnpatchInterfaceMethods(IBrowserService_PatchInfo); } diff --git a/chrome_frame/protocol_sink_wrap.cc b/chrome_frame/protocol_sink_wrap.cc index a567e9f..5559e61 100644 --- a/chrome_frame/protocol_sink_wrap.cc +++ b/chrome_frame/protocol_sink_wrap.cc @@ -14,9 +14,8 @@ #include "base/string_util.h" #include "chrome_frame/utils.h" -#include "chrome_frame/vtable_patch_manager.h" -// BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so +// BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so // not in everyone's headers yet. See: // http://msdn.microsoft.com/en-us/library/ms775133(VS.85,loband).aspx #ifndef BINDSTATUS_SERVER_MIMETYPEAVAILABLE @@ -25,22 +24,36 @@ static const wchar_t* kChromeMimeType = L"application/chromepage"; static const char kTextHtmlMimeType[] = "text/html"; +const wchar_t kUrlMonDllName[] = L"urlmon.dll"; static const int kInternetProtocolStartIndex = 3; static const int kInternetProtocolReadIndex = 9; static const int kInternetProtocolStartExIndex = 13; +// TODO(ananta) +// We should avoid duplicate VTable declarations. BEGIN_VTABLE_PATCHES(IInternetProtocol) VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) END_VTABLE_PATCHES() +BEGIN_VTABLE_PATCHES(IInternetProtocolSecure) + VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) + VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) +END_VTABLE_PATCHES() + BEGIN_VTABLE_PATCHES(IInternetProtocolEx) VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx) END_VTABLE_PATCHES() +BEGIN_VTABLE_PATCHES(IInternetProtocolExSecure) + VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) + VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) + VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx) +END_VTABLE_PATCHES() + // // ProtocolSinkWrap implementation // @@ -64,12 +77,45 @@ ProtocolSinkWrap::~ProtocolSinkWrap() { DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size(); } -bool ProtocolSinkWrap::PatchProtocolHandler(const wchar_t* dll, - const CLSID& handler_clsid) { - HMODULE module = ::GetModuleHandle(dll); +bool ProtocolSinkWrap::PatchProtocolHandlers() { + HRESULT hr = PatchProtocolMethods(CLSID_HttpProtocol, + IInternetProtocol_PatchInfo, + IInternetProtocolEx_PatchInfo); + if (FAILED(hr)) { + NOTREACHED() << "Failed to patch IInternetProtocol interface." + << " Error: " << hr; + return false; + } + + hr = PatchProtocolMethods(CLSID_HttpSProtocol, + IInternetProtocolSecure_PatchInfo, + IInternetProtocolExSecure_PatchInfo); + if (FAILED(hr)) { + NOTREACHED() << "Failed to patch IInternetProtocol secure interface." + << " Error: " << hr; + return false; + } + + return true; +} + +void ProtocolSinkWrap::UnpatchProtocolHandlers() { + vtable_patch::UnpatchInterfaceMethods(IInternetProtocol_PatchInfo); + vtable_patch::UnpatchInterfaceMethods(IInternetProtocolEx_PatchInfo); + vtable_patch::UnpatchInterfaceMethods(IInternetProtocolSecure_PatchInfo); + vtable_patch::UnpatchInterfaceMethods(IInternetProtocolExSecure_PatchInfo); +} + +HRESULT ProtocolSinkWrap::CreateProtocolHandlerInstance( + const CLSID& clsid, IInternetProtocol** protocol) { + if (!protocol) { + return E_INVALIDARG; + } + + HMODULE module = ::GetModuleHandle(kUrlMonDllName); if (!module) { NOTREACHED() << "urlmon is not yet loaded. Error: " << GetLastError(); - return false; + return E_FAIL; } typedef HRESULT (WINAPI* DllGetClassObject_Fn)(REFCLSID, REFIID, LPVOID*); @@ -77,15 +123,15 @@ bool ProtocolSinkWrap::PatchProtocolHandler(const wchar_t* dll, ::GetProcAddress(module, "DllGetClassObject")); if (!fn) { NOTREACHED() << "DllGetClassObject not found in urlmon.dll"; - return false; + return E_FAIL; } ScopedComPtr<IClassFactory> protocol_class_factory; - HRESULT hr = fn(handler_clsid, IID_IClassFactory, + HRESULT hr = fn(clsid, IID_IClassFactory, reinterpret_cast<LPVOID*>(protocol_class_factory.Receive())); if (FAILED(hr)) { NOTREACHED() << "DllGetclassObject failed. Error: " << hr; - return false; + return hr; } ScopedComPtr<IInternetProtocol> handler_instance; @@ -94,19 +140,39 @@ bool ProtocolSinkWrap::PatchProtocolHandler(const wchar_t* dll, if (FAILED(hr)) { NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol." << " Error: " << hr; + } else { + *protocol = handler_instance.Detach(); + } + + return hr; +} + +HRESULT ProtocolSinkWrap::PatchProtocolMethods( + const CLSID& clsid_protocol, + vtable_patch::MethodPatchInfo* protocol_patch_info, + vtable_patch::MethodPatchInfo* protocol_ex_patch_info) { + if (!protocol_patch_info || !protocol_ex_patch_info) { + return E_INVALIDARG; + } + + ScopedComPtr<IInternetProtocol> http_protocol; + HRESULT hr = CreateProtocolHandlerInstance(clsid_protocol, + http_protocol.Receive()); + if (FAILED(hr)) { + NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol." + << " Error: " << hr; return false; } ScopedComPtr<IInternetProtocolEx> ipex; - ipex.QueryFrom(handler_instance); + ipex.QueryFrom(http_protocol); if (ipex) { - vtable_patch::PatchInterfaceMethods(ipex, IInternetProtocolEx_PatchInfo); + hr = vtable_patch::PatchInterfaceMethods(ipex, protocol_ex_patch_info); } else { - vtable_patch::PatchInterfaceMethods(handler_instance, - IInternetProtocol_PatchInfo); + hr = vtable_patch::PatchInterfaceMethods(http_protocol, + protocol_patch_info); } - - return true; + return hr; } // IInternetProtocol/Ex method implementation. @@ -264,10 +330,10 @@ STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress, // determination logic. // NOTE: We use the report_data_recursiveness_ variable to detect situations - // in which calls to ReportData are re-entrant (such as when the entire - // contents of a page fit inside a single packet). In these cases, we - // don't care about re-entrant calls beyond the second, and so we compare - // report_data_recursiveness_ inside the while loop, making sure we skip + // in which calls to ReportData are re-entrant (such as when the entire + // contents of a page fit inside a single packet). In these cases, we + // don't care about re-entrant calls beyond the second, and so we compare + // report_data_recursiveness_ inside the while loop, making sure we skip // what would otherwise be spurious calls to ReportProgress(). report_data_recursiveness_++; diff --git a/chrome_frame/protocol_sink_wrap.h b/chrome_frame/protocol_sink_wrap.h index e4f9cfb..bab018b 100644 --- a/chrome_frame/protocol_sink_wrap.h +++ b/chrome_frame/protocol_sink_wrap.h @@ -17,6 +17,7 @@ #include "base/scoped_comptr_win.h" #include "googleurl/src/gurl.h" #include "chrome_frame/ie8_types.h" +#include "chrome_frame/vtable_patch_manager.h" // Typedefs for IInternetProtocol and related methods that we patch. typedef HRESULT (STDMETHODCALLTYPE* InternetProtocol_Start_Fn)( @@ -85,8 +86,12 @@ END_COM_MAP() bool Initialize(IInternetProtocol* protocol, IInternetProtocolSink* original_sink, const wchar_t* url); - static bool PatchProtocolHandler(const wchar_t* dll, - const CLSID& handler_clsid); + // VTable patches the IInternetProtocol and IIntenetProtocolEx interface. + // Returns true on success. + static bool PatchProtocolHandlers(); + + // Unpatches the IInternetProtocol and IInternetProtocolEx interfaces. + static void UnpatchProtocolHandlers(); // IInternetProtocol/Ex patches. static HRESULT STDMETHODCALLTYPE OnStart(InternetProtocol_Start_Fn orig_start, @@ -183,6 +188,21 @@ END_COM_MAP() return renderer_type_; } + // Creates an instance of the specified protocol handler and returns the + // IInternetProtocol interface pointer. + // Returns S_OK on success. + static HRESULT CreateProtocolHandlerInstance(const CLSID& clsid, + IInternetProtocol** protocol); + + // Helper function for patching the VTable of the IInternetProtocol + // interface. It instantiates the object identified by the protocol_clsid + // parameter and patches its VTable. + // Returns S_OK on success. + static HRESULT PatchProtocolMethods( + const CLSID& protocol_clsid, + vtable_patch::MethodPatchInfo* protocol_patch_info, + vtable_patch::MethodPatchInfo* protocol_ex_patch_info); + // WARNING: Don't use GURL variables here. Please see // http://b/issue?id=2102171 for details. @@ -202,7 +222,7 @@ END_COM_MAP() HRESULT result_code_; DWORD result_error_; std::wstring result_text_; - // For tracking re-entrency and preventing duplicate Read()s from + // For tracking re-entrency and preventing duplicate Read()s from // distorting the outcome of ReportData. int report_data_recursiveness_; diff --git a/chrome_frame/vtable_patch_manager.cc b/chrome_frame/vtable_patch_manager.cc index 5f15158..0b9f79aa 100644 --- a/chrome_frame/vtable_patch_manager.cc +++ b/chrome_frame/vtable_patch_manager.cc @@ -30,6 +30,15 @@ HRESULT PatchInterfaceMethods(void* unknown, MethodPatchInfo* patches) { DCHECK(vtable); for (MethodPatchInfo* it = patches; it->index_ != -1; ++it) { + if (it->stub_ != NULL) { + // If this DCHECK fires it means that we are using the same VTable + // information to patch two different interfaces. + DCHECK(false); + DLOG(ERROR) << "Attempting to patch two different VTables with the " + << "same VTable information"; + continue; + } + PROC original_fn = vtable[it->index_]; FunctionStub* stub = FunctionStub::FromCode(original_fn); if (stub != NULL) { @@ -65,7 +74,7 @@ HRESULT PatchInterfaceMethods(void* unknown, MethodPatchInfo* patches) { HRESULT UnpatchInterfaceMethods(MethodPatchInfo* patches) { for (MethodPatchInfo* it = patches; it->index_ != -1; ++it) { if (it->stub_) { - DCHECK(it->stub_->absolute_target() == + DCHECK(it->stub_->absolute_target() == reinterpret_cast<uintptr_t>(it->method_)); // Modify the stub to just jump directly to the original function. it->stub_->BypassStub(reinterpret_cast<void*>(it->stub_->argument())); |