summaryrefslogtreecommitdiffstats
path: root/chrome_frame
diff options
context:
space:
mode:
Diffstat (limited to 'chrome_frame')
-rw-r--r--chrome_frame/bho.cc12
-rw-r--r--chrome_frame/http_negotiate.cc129
-rw-r--r--chrome_frame/http_negotiate.h47
-rw-r--r--chrome_frame/test/http_negotiate_unittest.cc106
-rw-r--r--chrome_frame/test/navigation_test.cc9
5 files changed, 300 insertions, 3 deletions
diff --git a/chrome_frame/bho.cc b/chrome_frame/bho.cc
index 7fbc473..f7e8129 100644
--- a/chrome_frame/bho.cc
+++ b/chrome_frame/bho.cc
@@ -319,6 +319,12 @@ bool PatchHelper::InitializeAndPatchProtocolsIfNeeded() {
if (state_ == UNKNOWN) {
g_trans_hooks.InstallHooks();
+ // IE9 sends the short user agent by default. To enable websites to
+ // identify and send content specific to chrome frame we need the
+ // negotiate patch which adds the user agent to outgoing requests.
+ if (GetIEVersion() == IE_9) {
+ HttpNegotiatePatch::Initialize();
+ }
state_ = PATCH_PROTOCOL;
ret = true;
}
@@ -339,9 +345,9 @@ void PatchHelper::PatchBrowserService(IBrowserService* browser_service) {
void PatchHelper::UnpatchIfNeeded() {
if (state_ == PATCH_PROTOCOL) {
g_trans_hooks.RevertHooks();
- } else if (state_ == PATCH_IBROWSER) {
- vtable_patch::UnpatchInterfaceMethods(IBrowserService_PatchInfo);
- MonikerPatch::Uninitialize();
+ if (GetIEVersion() == IE_9) {
+ HttpNegotiatePatch::Uninitialize();
+ }
}
state_ = UNKNOWN;
}
diff --git a/chrome_frame/http_negotiate.cc b/chrome_frame/http_negotiate.cc
index 85701c3..b593f28 100644
--- a/chrome_frame/http_negotiate.cc
+++ b/chrome_frame/http_negotiate.cc
@@ -32,6 +32,56 @@ const char kLowerCaseUserAgent[] = "user-agent";
// TODO(robertshield): Remove this once we update our SDK version.
const int LOCAL_BINDSTATUS_SERVER_MIMETYPEAVAILABLE = 54;
+static const int kHttpNegotiateBeginningTransactionIndex = 3;
+
+BEGIN_VTABLE_PATCHES(IHttpNegotiate)
+ VTABLE_PATCH_ENTRY(kHttpNegotiateBeginningTransactionIndex,
+ HttpNegotiatePatch::BeginningTransaction)
+END_VTABLE_PATCHES()
+
+namespace {
+
+class SimpleBindStatusCallback : public CComObjectRootEx<CComSingleThreadModel>,
+ public IBindStatusCallback {
+ public:
+ BEGIN_COM_MAP(SimpleBindStatusCallback)
+ COM_INTERFACE_ENTRY(IBindStatusCallback)
+ END_COM_MAP()
+
+ // IBindStatusCallback implementation
+ STDMETHOD(OnStartBinding)(DWORD reserved, IBinding* binding) {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(GetPriority)(LONG* priority) {
+ return E_NOTIMPL;
+ }
+ STDMETHOD(OnLowResource)(DWORD reserved) {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(OnProgress)(ULONG progress, ULONG max_progress,
+ ULONG status_code, LPCWSTR status_text) {
+ return E_NOTIMPL;
+ }
+ STDMETHOD(OnStopBinding)(HRESULT result, LPCWSTR error) {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(GetBindInfo)(DWORD* bind_flags, BINDINFO* bind_info) {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(OnDataAvailable)(DWORD flags, DWORD size, FORMATETC* formatetc,
+ STGMEDIUM* storage) {
+ return E_NOTIMPL;
+ }
+ STDMETHOD(OnObjectAvailable)(REFIID iid, IUnknown* object) {
+ return E_NOTIMPL;
+ }
+};
+} // end namespace
+
std::string AppendCFUserAgentString(LPCWSTR headers,
LPCWSTR additional_headers) {
using net::HttpUtil;
@@ -107,3 +157,82 @@ std::string ReplaceOrAddUserAgent(LPCWSTR headers,
return new_headers;
}
+HttpNegotiatePatch::HttpNegotiatePatch() {
+}
+
+HttpNegotiatePatch::~HttpNegotiatePatch() {
+}
+
+// static
+bool HttpNegotiatePatch::Initialize() {
+ if (IS_PATCHED(IHttpNegotiate)) {
+ DLOG(WARNING) << __FUNCTION__ << " called more than once.";
+ return true;
+ }
+ // Use our SimpleBindStatusCallback class as we need a temporary object that
+ // implements IBindStatusCallback.
+ CComObjectStackEx<SimpleBindStatusCallback> request;
+ ScopedComPtr<IBindCtx> bind_ctx;
+ HRESULT hr = CreateAsyncBindCtx(0, &request, NULL, bind_ctx.Receive());
+ DCHECK(SUCCEEDED(hr)) << "CreateAsyncBindCtx";
+ if (bind_ctx) {
+ ScopedComPtr<IUnknown> bscb_holder;
+ bind_ctx->GetObjectParam(L"_BSCB_Holder_", bscb_holder.Receive());
+ if (bscb_holder) {
+ hr = PatchHttpNegotiate(bscb_holder);
+ } else {
+ NOTREACHED() << "Failed to get _BSCB_Holder_";
+ hr = E_UNEXPECTED;
+ }
+ bind_ctx.Release();
+ }
+
+ return SUCCEEDED(hr);
+}
+
+// static
+void HttpNegotiatePatch::Uninitialize() {
+ vtable_patch::UnpatchInterfaceMethods(IHttpNegotiate_PatchInfo);
+}
+
+// static
+HRESULT HttpNegotiatePatch::PatchHttpNegotiate(IUnknown* to_patch) {
+ DCHECK(to_patch);
+ DCHECK_IS_NOT_PATCHED(IHttpNegotiate);
+
+ ScopedComPtr<IHttpNegotiate> http;
+ HRESULT hr = http.QueryFrom(to_patch);
+ if (FAILED(hr)) {
+ hr = DoQueryService(IID_IHttpNegotiate, to_patch, http.Receive());
+ }
+
+ if (http) {
+ hr = vtable_patch::PatchInterfaceMethods(http, IHttpNegotiate_PatchInfo);
+ DLOG_IF(ERROR, FAILED(hr))
+ << base::StringPrintf("HttpNegotiate patch failed 0x%08X", hr);
+ } else {
+ DLOG(WARNING)
+ << base::StringPrintf("IHttpNegotiate not supported 0x%08X", hr);
+ }
+ return hr;
+}
+
+// static
+HRESULT HttpNegotiatePatch::BeginningTransaction(
+ IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me,
+ LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers) {
+ DVLOG(1) << __FUNCTION__ << " " << url << " headers:\n" << headers;
+
+ HRESULT hr = original(me, url, headers, reserved, additional_headers);
+
+ if (FAILED(hr)) {
+ DLOG(WARNING) << __FUNCTION__ << " Delegate returned an error";
+ return hr;
+ }
+ std::string updated(AppendCFUserAgentString(headers, *additional_headers));
+ *additional_headers = reinterpret_cast<wchar_t*>(::CoTaskMemRealloc(
+ *additional_headers, (updated.length() + 1) * sizeof(wchar_t)));
+ lstrcpyW(*additional_headers, ASCIIToWide(updated).c_str());
+ return S_OK;
+}
+
diff --git a/chrome_frame/http_negotiate.h b/chrome_frame/http_negotiate.h
index 151ad23..eb17c7f 100644
--- a/chrome_frame/http_negotiate.h
+++ b/chrome_frame/http_negotiate.h
@@ -12,6 +12,53 @@
#include "base/basictypes.h"
#include "base/scoped_comptr_win.h"
+// Typedefs for IHttpNegotiate methods.
+typedef HRESULT (STDMETHODCALLTYPE* IHttpNegotiate_BeginningTransaction_Fn)(
+ IHttpNegotiate* me, LPCWSTR url, LPCWSTR headers, DWORD reserved,
+ LPWSTR* additional_headers);
+typedef HRESULT (STDMETHODCALLTYPE* IHttpNegotiate_OnResponse_Fn)(
+ IHttpNegotiate* me, DWORD response_code, LPCWSTR response_header,
+ LPCWSTR request_header, LPWSTR* additional_request_headers);
+
+// Typedefs for IBindStatusCallback methods.
+typedef HRESULT (STDMETHODCALLTYPE* IBindStatusCallback_StartBinding_Fn)(
+ IBindStatusCallback* me, DWORD reserved, IBinding *binding);
+
+// Typedefs for IInternetProtocolSink methods.
+typedef HRESULT (STDMETHODCALLTYPE* IInternetProtocolSink_ReportProgress_Fn)(
+ IInternetProtocolSink* me, ULONG status_code, LPCWSTR status_text);
+
+// Patches methods of urlmon's IHttpNegotiate implementation for the purposes
+// of adding to the http user agent header.
+
+// Also patches one of the IBindStatusCallback implementations in urlmon to pick
+// up an IBinding during the StartBinding call. The IBinding implementor then
+// gets a patch applied to its IInternetProtocolSink's ReportProgress method.
+// The patched is there so that the reporting of the MIME type to the IBinding
+// implementor can be changed if an X-Chrome-Frame HTTP header is present
+// in the response headers. If anyone can suggest a more straightforward way of
+// doing this, I would be eternally grateful.
+class HttpNegotiatePatch {
+ // class is not to be instantiated atm.
+ HttpNegotiatePatch();
+ ~HttpNegotiatePatch();
+
+ public:
+ static bool Initialize();
+ static void Uninitialize();
+
+ // IHttpNegotiate patch methods
+ static STDMETHODIMP BeginningTransaction(
+ IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me,
+ LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers);
+
+ protected:
+ static HRESULT PatchHttpNegotiate(IUnknown* to_patch);
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(HttpNegotiatePatch);
+};
+
// From the latest urlmon.h. Symbol name prepended with LOCAL_ to
// avoid conflict (and therefore build errors) for those building with
// a newer Windows SDK.
diff --git a/chrome_frame/test/http_negotiate_unittest.cc b/chrome_frame/test/http_negotiate_unittest.cc
index 857197e..c56034b 100644
--- a/chrome_frame/test/http_negotiate_unittest.cc
+++ b/chrome_frame/test/http_negotiate_unittest.cc
@@ -16,6 +16,112 @@
#include "gtest/gtest.h"
#include "gmock/gmock.h"
+class HttpNegotiateTest : public testing::Test {
+ protected:
+ HttpNegotiateTest() {
+ }
+};
+
+class TestHttpNegotiate
+ : public CComObjectRootEx<CComMultiThreadModel>,
+ public IHttpNegotiate {
+ public:
+ TestHttpNegotiate()
+ : beginning_transaction_ret_(S_OK), additional_headers_(NULL) {
+ }
+
+BEGIN_COM_MAP(TestHttpNegotiate)
+ COM_INTERFACE_ENTRY(IHttpNegotiate)
+END_COM_MAP()
+ STDMETHOD(BeginningTransaction)(LPCWSTR url, LPCWSTR headers, // NOLINT
+ DWORD reserved, // NOLINT
+ LPWSTR* additional_headers) { // NOLINT
+ if (additional_headers_) {
+ int len = lstrlenW(additional_headers_);
+ len++;
+ *additional_headers = reinterpret_cast<wchar_t*>(
+ ::CoTaskMemAlloc(len * sizeof(wchar_t)));
+ lstrcpyW(*additional_headers, additional_headers_);
+ }
+ return beginning_transaction_ret_;
+ }
+
+ STDMETHOD(OnResponse)(DWORD response_code, LPCWSTR response_header,
+ LPCWSTR request_header,
+ LPWSTR* additional_request_headers) {
+ return S_OK;
+ }
+
+ HRESULT beginning_transaction_ret_;
+ const wchar_t* additional_headers_;
+};
+
+TEST_F(HttpNegotiateTest, BeginningTransaction) {
+ static const int kBeginningTransactionIndex = 3;
+ CComObjectStackEx<TestHttpNegotiate> test_http;
+ IHttpNegotiate_BeginningTransaction_Fn original =
+ reinterpret_cast<IHttpNegotiate_BeginningTransaction_Fn>(
+ (*reinterpret_cast<void***>(
+ static_cast<IHttpNegotiate*>(
+ &test_http)))[kBeginningTransactionIndex]);
+
+ std::wstring cf_ua(
+ ASCIIToWide(http_utils::GetDefaultUserAgentHeaderWithCFTag()));
+ std::wstring cf_tag(
+ ASCIIToWide(http_utils::GetChromeFrameUserAgent()));
+
+ EXPECT_NE(std::wstring::npos, cf_ua.find(cf_tag));
+
+ struct TestCase {
+ const std::wstring original_headers_;
+ const std::wstring delegate_additional_;
+ const std::wstring expected_additional_;
+ HRESULT delegate_return_value_;
+ } test_cases[] = {
+ { L"Accept: */*\r\n",
+ L"",
+ cf_ua + L"\r\n",
+ S_OK },
+ { L"Accept: */*\r\n",
+ L"",
+ L"",
+ E_OUTOFMEMORY },
+ { L"",
+ L"Accept: */*\r\n",
+ L"Accept: */*\r\n" + cf_ua + L"\r\n",
+ S_OK },
+ { L"User-Agent: Bingo/1.0\r\n",
+ L"",
+ L"User-Agent: Bingo/1.0 " + cf_tag + L"\r\n",
+ S_OK },
+ { L"User-Agent: NotMe/1.0\r\n",
+ L"User-Agent: MeMeMe/1.0\r\n",
+ L"User-Agent: MeMeMe/1.0 " + cf_tag + L"\r\n",
+ S_OK },
+ { L"",
+ L"User-Agent: MeMeMe/1.0\r\n",
+ L"User-Agent: MeMeMe/1.0 " + cf_tag + L"\r\n",
+ S_OK },
+ };
+
+ for (int i = 0; i < arraysize(test_cases); ++i) {
+ TestCase& test = test_cases[i];
+ wchar_t* additional = NULL;
+ test_http.beginning_transaction_ret_ = test.delegate_return_value_;
+ test_http.additional_headers_ = test.delegate_additional_.c_str();
+ HttpNegotiatePatch::BeginningTransaction(original, &test_http,
+ L"http://www.google.com", test.original_headers_.c_str(), 0,
+ &additional);
+ EXPECT_TRUE(additional != NULL);
+
+ if (additional) {
+ // Check against the expected additional headers.
+ EXPECT_EQ(test.expected_additional_, std::wstring(additional));
+ ::CoTaskMemFree(additional);
+ }
+ }
+}
+
class TestInternetProtocolSink
: public CComObjectRootEx<CComMultiThreadModel>,
public IInternetProtocolSink {
diff --git a/chrome_frame/test/navigation_test.cc b/chrome_frame/test/navigation_test.cc
index 7e157ee..ff47ba5 100644
--- a/chrome_frame/test/navigation_test.cc
+++ b/chrome_frame/test/navigation_test.cc
@@ -1133,8 +1133,17 @@ TEST_P(FullTabNavigationTest, RefreshContentsUATest) {
bool in_cf = GetParam().invokes_cf();
if (in_cf) {
headers.append("X-UA-Compatible: chrome=1\r\n");
+ } else {
+ if (GetInstalledIEVersion() == IE_9) {
+ LOG(ERROR) << "Test disabled for IE9";
+ return;
+ }
}
+ EXPECT_CALL(server_mock_, Get(_, testing::StrCaseEq(L"/favicon.ico"), _))
+ .Times(testing::AtMost(2))
+ .WillRepeatedly(SendFast("HTTP/1.1 404 Not Found", ""));
+
std::wstring src_url = server_mock_.Resolve(L"/refresh_src.html");
EXPECT_CALL(server_mock_, Get(_, StrEq(L"/refresh_src.html"),