diff options
Diffstat (limited to 'chrome_frame')
-rw-r--r-- | chrome_frame/bho.cc | 12 | ||||
-rw-r--r-- | chrome_frame/http_negotiate.cc | 129 | ||||
-rw-r--r-- | chrome_frame/http_negotiate.h | 47 | ||||
-rw-r--r-- | chrome_frame/test/http_negotiate_unittest.cc | 106 | ||||
-rw-r--r-- | chrome_frame/test/navigation_test.cc | 9 |
5 files changed, 300 insertions, 3 deletions
diff --git a/chrome_frame/bho.cc b/chrome_frame/bho.cc index 7fbc473..f7e8129 100644 --- a/chrome_frame/bho.cc +++ b/chrome_frame/bho.cc @@ -319,6 +319,12 @@ bool PatchHelper::InitializeAndPatchProtocolsIfNeeded() { if (state_ == UNKNOWN) { g_trans_hooks.InstallHooks(); + // IE9 sends the short user agent by default. To enable websites to + // identify and send content specific to chrome frame we need the + // negotiate patch which adds the user agent to outgoing requests. + if (GetIEVersion() == IE_9) { + HttpNegotiatePatch::Initialize(); + } state_ = PATCH_PROTOCOL; ret = true; } @@ -339,9 +345,9 @@ void PatchHelper::PatchBrowserService(IBrowserService* browser_service) { void PatchHelper::UnpatchIfNeeded() { if (state_ == PATCH_PROTOCOL) { g_trans_hooks.RevertHooks(); - } else if (state_ == PATCH_IBROWSER) { - vtable_patch::UnpatchInterfaceMethods(IBrowserService_PatchInfo); - MonikerPatch::Uninitialize(); + if (GetIEVersion() == IE_9) { + HttpNegotiatePatch::Uninitialize(); + } } state_ = UNKNOWN; } diff --git a/chrome_frame/http_negotiate.cc b/chrome_frame/http_negotiate.cc index 85701c3..b593f28 100644 --- a/chrome_frame/http_negotiate.cc +++ b/chrome_frame/http_negotiate.cc @@ -32,6 +32,56 @@ const char kLowerCaseUserAgent[] = "user-agent"; // TODO(robertshield): Remove this once we update our SDK version. const int LOCAL_BINDSTATUS_SERVER_MIMETYPEAVAILABLE = 54; +static const int kHttpNegotiateBeginningTransactionIndex = 3; + +BEGIN_VTABLE_PATCHES(IHttpNegotiate) + VTABLE_PATCH_ENTRY(kHttpNegotiateBeginningTransactionIndex, + HttpNegotiatePatch::BeginningTransaction) +END_VTABLE_PATCHES() + +namespace { + +class SimpleBindStatusCallback : public CComObjectRootEx<CComSingleThreadModel>, + public IBindStatusCallback { + public: + BEGIN_COM_MAP(SimpleBindStatusCallback) + COM_INTERFACE_ENTRY(IBindStatusCallback) + END_COM_MAP() + + // IBindStatusCallback implementation + STDMETHOD(OnStartBinding)(DWORD reserved, IBinding* binding) { + return E_NOTIMPL; + } + + STDMETHOD(GetPriority)(LONG* priority) { + return E_NOTIMPL; + } + STDMETHOD(OnLowResource)(DWORD reserved) { + return E_NOTIMPL; + } + + STDMETHOD(OnProgress)(ULONG progress, ULONG max_progress, + ULONG status_code, LPCWSTR status_text) { + return E_NOTIMPL; + } + STDMETHOD(OnStopBinding)(HRESULT result, LPCWSTR error) { + return E_NOTIMPL; + } + + STDMETHOD(GetBindInfo)(DWORD* bind_flags, BINDINFO* bind_info) { + return E_NOTIMPL; + } + + STDMETHOD(OnDataAvailable)(DWORD flags, DWORD size, FORMATETC* formatetc, + STGMEDIUM* storage) { + return E_NOTIMPL; + } + STDMETHOD(OnObjectAvailable)(REFIID iid, IUnknown* object) { + return E_NOTIMPL; + } +}; +} // end namespace + std::string AppendCFUserAgentString(LPCWSTR headers, LPCWSTR additional_headers) { using net::HttpUtil; @@ -107,3 +157,82 @@ std::string ReplaceOrAddUserAgent(LPCWSTR headers, return new_headers; } +HttpNegotiatePatch::HttpNegotiatePatch() { +} + +HttpNegotiatePatch::~HttpNegotiatePatch() { +} + +// static +bool HttpNegotiatePatch::Initialize() { + if (IS_PATCHED(IHttpNegotiate)) { + DLOG(WARNING) << __FUNCTION__ << " called more than once."; + return true; + } + // Use our SimpleBindStatusCallback class as we need a temporary object that + // implements IBindStatusCallback. + CComObjectStackEx<SimpleBindStatusCallback> request; + ScopedComPtr<IBindCtx> bind_ctx; + HRESULT hr = CreateAsyncBindCtx(0, &request, NULL, bind_ctx.Receive()); + DCHECK(SUCCEEDED(hr)) << "CreateAsyncBindCtx"; + if (bind_ctx) { + ScopedComPtr<IUnknown> bscb_holder; + bind_ctx->GetObjectParam(L"_BSCB_Holder_", bscb_holder.Receive()); + if (bscb_holder) { + hr = PatchHttpNegotiate(bscb_holder); + } else { + NOTREACHED() << "Failed to get _BSCB_Holder_"; + hr = E_UNEXPECTED; + } + bind_ctx.Release(); + } + + return SUCCEEDED(hr); +} + +// static +void HttpNegotiatePatch::Uninitialize() { + vtable_patch::UnpatchInterfaceMethods(IHttpNegotiate_PatchInfo); +} + +// static +HRESULT HttpNegotiatePatch::PatchHttpNegotiate(IUnknown* to_patch) { + DCHECK(to_patch); + DCHECK_IS_NOT_PATCHED(IHttpNegotiate); + + ScopedComPtr<IHttpNegotiate> http; + HRESULT hr = http.QueryFrom(to_patch); + if (FAILED(hr)) { + hr = DoQueryService(IID_IHttpNegotiate, to_patch, http.Receive()); + } + + if (http) { + hr = vtable_patch::PatchInterfaceMethods(http, IHttpNegotiate_PatchInfo); + DLOG_IF(ERROR, FAILED(hr)) + << base::StringPrintf("HttpNegotiate patch failed 0x%08X", hr); + } else { + DLOG(WARNING) + << base::StringPrintf("IHttpNegotiate not supported 0x%08X", hr); + } + return hr; +} + +// static +HRESULT HttpNegotiatePatch::BeginningTransaction( + IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me, + LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers) { + DVLOG(1) << __FUNCTION__ << " " << url << " headers:\n" << headers; + + HRESULT hr = original(me, url, headers, reserved, additional_headers); + + if (FAILED(hr)) { + DLOG(WARNING) << __FUNCTION__ << " Delegate returned an error"; + return hr; + } + std::string updated(AppendCFUserAgentString(headers, *additional_headers)); + *additional_headers = reinterpret_cast<wchar_t*>(::CoTaskMemRealloc( + *additional_headers, (updated.length() + 1) * sizeof(wchar_t))); + lstrcpyW(*additional_headers, ASCIIToWide(updated).c_str()); + return S_OK; +} + diff --git a/chrome_frame/http_negotiate.h b/chrome_frame/http_negotiate.h index 151ad23..eb17c7f 100644 --- a/chrome_frame/http_negotiate.h +++ b/chrome_frame/http_negotiate.h @@ -12,6 +12,53 @@ #include "base/basictypes.h" #include "base/scoped_comptr_win.h" +// Typedefs for IHttpNegotiate methods. +typedef HRESULT (STDMETHODCALLTYPE* IHttpNegotiate_BeginningTransaction_Fn)( + IHttpNegotiate* me, LPCWSTR url, LPCWSTR headers, DWORD reserved, + LPWSTR* additional_headers); +typedef HRESULT (STDMETHODCALLTYPE* IHttpNegotiate_OnResponse_Fn)( + IHttpNegotiate* me, DWORD response_code, LPCWSTR response_header, + LPCWSTR request_header, LPWSTR* additional_request_headers); + +// Typedefs for IBindStatusCallback methods. +typedef HRESULT (STDMETHODCALLTYPE* IBindStatusCallback_StartBinding_Fn)( + IBindStatusCallback* me, DWORD reserved, IBinding *binding); + +// Typedefs for IInternetProtocolSink methods. +typedef HRESULT (STDMETHODCALLTYPE* IInternetProtocolSink_ReportProgress_Fn)( + IInternetProtocolSink* me, ULONG status_code, LPCWSTR status_text); + +// Patches methods of urlmon's IHttpNegotiate implementation for the purposes +// of adding to the http user agent header. + +// Also patches one of the IBindStatusCallback implementations in urlmon to pick +// up an IBinding during the StartBinding call. The IBinding implementor then +// gets a patch applied to its IInternetProtocolSink's ReportProgress method. +// The patched is there so that the reporting of the MIME type to the IBinding +// implementor can be changed if an X-Chrome-Frame HTTP header is present +// in the response headers. If anyone can suggest a more straightforward way of +// doing this, I would be eternally grateful. +class HttpNegotiatePatch { + // class is not to be instantiated atm. + HttpNegotiatePatch(); + ~HttpNegotiatePatch(); + + public: + static bool Initialize(); + static void Uninitialize(); + + // IHttpNegotiate patch methods + static STDMETHODIMP BeginningTransaction( + IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me, + LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers); + + protected: + static HRESULT PatchHttpNegotiate(IUnknown* to_patch); + + private: + DISALLOW_COPY_AND_ASSIGN(HttpNegotiatePatch); +}; + // From the latest urlmon.h. Symbol name prepended with LOCAL_ to // avoid conflict (and therefore build errors) for those building with // a newer Windows SDK. diff --git a/chrome_frame/test/http_negotiate_unittest.cc b/chrome_frame/test/http_negotiate_unittest.cc index 857197e..c56034b 100644 --- a/chrome_frame/test/http_negotiate_unittest.cc +++ b/chrome_frame/test/http_negotiate_unittest.cc @@ -16,6 +16,112 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +class HttpNegotiateTest : public testing::Test { + protected: + HttpNegotiateTest() { + } +}; + +class TestHttpNegotiate + : public CComObjectRootEx<CComMultiThreadModel>, + public IHttpNegotiate { + public: + TestHttpNegotiate() + : beginning_transaction_ret_(S_OK), additional_headers_(NULL) { + } + +BEGIN_COM_MAP(TestHttpNegotiate) + COM_INTERFACE_ENTRY(IHttpNegotiate) +END_COM_MAP() + STDMETHOD(BeginningTransaction)(LPCWSTR url, LPCWSTR headers, // NOLINT + DWORD reserved, // NOLINT + LPWSTR* additional_headers) { // NOLINT + if (additional_headers_) { + int len = lstrlenW(additional_headers_); + len++; + *additional_headers = reinterpret_cast<wchar_t*>( + ::CoTaskMemAlloc(len * sizeof(wchar_t))); + lstrcpyW(*additional_headers, additional_headers_); + } + return beginning_transaction_ret_; + } + + STDMETHOD(OnResponse)(DWORD response_code, LPCWSTR response_header, + LPCWSTR request_header, + LPWSTR* additional_request_headers) { + return S_OK; + } + + HRESULT beginning_transaction_ret_; + const wchar_t* additional_headers_; +}; + +TEST_F(HttpNegotiateTest, BeginningTransaction) { + static const int kBeginningTransactionIndex = 3; + CComObjectStackEx<TestHttpNegotiate> test_http; + IHttpNegotiate_BeginningTransaction_Fn original = + reinterpret_cast<IHttpNegotiate_BeginningTransaction_Fn>( + (*reinterpret_cast<void***>( + static_cast<IHttpNegotiate*>( + &test_http)))[kBeginningTransactionIndex]); + + std::wstring cf_ua( + ASCIIToWide(http_utils::GetDefaultUserAgentHeaderWithCFTag())); + std::wstring cf_tag( + ASCIIToWide(http_utils::GetChromeFrameUserAgent())); + + EXPECT_NE(std::wstring::npos, cf_ua.find(cf_tag)); + + struct TestCase { + const std::wstring original_headers_; + const std::wstring delegate_additional_; + const std::wstring expected_additional_; + HRESULT delegate_return_value_; + } test_cases[] = { + { L"Accept: */*\r\n", + L"", + cf_ua + L"\r\n", + S_OK }, + { L"Accept: */*\r\n", + L"", + L"", + E_OUTOFMEMORY }, + { L"", + L"Accept: */*\r\n", + L"Accept: */*\r\n" + cf_ua + L"\r\n", + S_OK }, + { L"User-Agent: Bingo/1.0\r\n", + L"", + L"User-Agent: Bingo/1.0 " + cf_tag + L"\r\n", + S_OK }, + { L"User-Agent: NotMe/1.0\r\n", + L"User-Agent: MeMeMe/1.0\r\n", + L"User-Agent: MeMeMe/1.0 " + cf_tag + L"\r\n", + S_OK }, + { L"", + L"User-Agent: MeMeMe/1.0\r\n", + L"User-Agent: MeMeMe/1.0 " + cf_tag + L"\r\n", + S_OK }, + }; + + for (int i = 0; i < arraysize(test_cases); ++i) { + TestCase& test = test_cases[i]; + wchar_t* additional = NULL; + test_http.beginning_transaction_ret_ = test.delegate_return_value_; + test_http.additional_headers_ = test.delegate_additional_.c_str(); + HttpNegotiatePatch::BeginningTransaction(original, &test_http, + L"http://www.google.com", test.original_headers_.c_str(), 0, + &additional); + EXPECT_TRUE(additional != NULL); + + if (additional) { + // Check against the expected additional headers. + EXPECT_EQ(test.expected_additional_, std::wstring(additional)); + ::CoTaskMemFree(additional); + } + } +} + class TestInternetProtocolSink : public CComObjectRootEx<CComMultiThreadModel>, public IInternetProtocolSink { diff --git a/chrome_frame/test/navigation_test.cc b/chrome_frame/test/navigation_test.cc index 7e157ee..ff47ba5 100644 --- a/chrome_frame/test/navigation_test.cc +++ b/chrome_frame/test/navigation_test.cc @@ -1133,8 +1133,17 @@ TEST_P(FullTabNavigationTest, RefreshContentsUATest) { bool in_cf = GetParam().invokes_cf(); if (in_cf) { headers.append("X-UA-Compatible: chrome=1\r\n"); + } else { + if (GetInstalledIEVersion() == IE_9) { + LOG(ERROR) << "Test disabled for IE9"; + return; + } } + EXPECT_CALL(server_mock_, Get(_, testing::StrCaseEq(L"/favicon.ico"), _)) + .Times(testing::AtMost(2)) + .WillRepeatedly(SendFast("HTTP/1.1 404 Not Found", "")); + std::wstring src_url = server_mock_.Resolve(L"/refresh_src.html"); EXPECT_CALL(server_mock_, Get(_, StrEq(L"/refresh_src.html"), |