diff options
Diffstat (limited to 'chrome_frame/http_negotiate.cc')
-rw-r--r-- | chrome_frame/http_negotiate.cc | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/chrome_frame/http_negotiate.cc b/chrome_frame/http_negotiate.cc index 85701c3..b593f28 100644 --- a/chrome_frame/http_negotiate.cc +++ b/chrome_frame/http_negotiate.cc @@ -32,6 +32,56 @@ const char kLowerCaseUserAgent[] = "user-agent"; // TODO(robertshield): Remove this once we update our SDK version. const int LOCAL_BINDSTATUS_SERVER_MIMETYPEAVAILABLE = 54; +static const int kHttpNegotiateBeginningTransactionIndex = 3; + +BEGIN_VTABLE_PATCHES(IHttpNegotiate) + VTABLE_PATCH_ENTRY(kHttpNegotiateBeginningTransactionIndex, + HttpNegotiatePatch::BeginningTransaction) +END_VTABLE_PATCHES() + +namespace { + +class SimpleBindStatusCallback : public CComObjectRootEx<CComSingleThreadModel>, + public IBindStatusCallback { + public: + BEGIN_COM_MAP(SimpleBindStatusCallback) + COM_INTERFACE_ENTRY(IBindStatusCallback) + END_COM_MAP() + + // IBindStatusCallback implementation + STDMETHOD(OnStartBinding)(DWORD reserved, IBinding* binding) { + return E_NOTIMPL; + } + + STDMETHOD(GetPriority)(LONG* priority) { + return E_NOTIMPL; + } + STDMETHOD(OnLowResource)(DWORD reserved) { + return E_NOTIMPL; + } + + STDMETHOD(OnProgress)(ULONG progress, ULONG max_progress, + ULONG status_code, LPCWSTR status_text) { + return E_NOTIMPL; + } + STDMETHOD(OnStopBinding)(HRESULT result, LPCWSTR error) { + return E_NOTIMPL; + } + + STDMETHOD(GetBindInfo)(DWORD* bind_flags, BINDINFO* bind_info) { + return E_NOTIMPL; + } + + STDMETHOD(OnDataAvailable)(DWORD flags, DWORD size, FORMATETC* formatetc, + STGMEDIUM* storage) { + return E_NOTIMPL; + } + STDMETHOD(OnObjectAvailable)(REFIID iid, IUnknown* object) { + return E_NOTIMPL; + } +}; +} // end namespace + std::string AppendCFUserAgentString(LPCWSTR headers, LPCWSTR additional_headers) { using net::HttpUtil; @@ -107,3 +157,82 @@ std::string ReplaceOrAddUserAgent(LPCWSTR headers, return new_headers; } +HttpNegotiatePatch::HttpNegotiatePatch() { +} + +HttpNegotiatePatch::~HttpNegotiatePatch() { +} + +// static +bool HttpNegotiatePatch::Initialize() { + if (IS_PATCHED(IHttpNegotiate)) { + DLOG(WARNING) << __FUNCTION__ << " called more than once."; + return true; + } + // Use our SimpleBindStatusCallback class as we need a temporary object that + // implements IBindStatusCallback. + CComObjectStackEx<SimpleBindStatusCallback> request; + ScopedComPtr<IBindCtx> bind_ctx; + HRESULT hr = CreateAsyncBindCtx(0, &request, NULL, bind_ctx.Receive()); + DCHECK(SUCCEEDED(hr)) << "CreateAsyncBindCtx"; + if (bind_ctx) { + ScopedComPtr<IUnknown> bscb_holder; + bind_ctx->GetObjectParam(L"_BSCB_Holder_", bscb_holder.Receive()); + if (bscb_holder) { + hr = PatchHttpNegotiate(bscb_holder); + } else { + NOTREACHED() << "Failed to get _BSCB_Holder_"; + hr = E_UNEXPECTED; + } + bind_ctx.Release(); + } + + return SUCCEEDED(hr); +} + +// static +void HttpNegotiatePatch::Uninitialize() { + vtable_patch::UnpatchInterfaceMethods(IHttpNegotiate_PatchInfo); +} + +// static +HRESULT HttpNegotiatePatch::PatchHttpNegotiate(IUnknown* to_patch) { + DCHECK(to_patch); + DCHECK_IS_NOT_PATCHED(IHttpNegotiate); + + ScopedComPtr<IHttpNegotiate> http; + HRESULT hr = http.QueryFrom(to_patch); + if (FAILED(hr)) { + hr = DoQueryService(IID_IHttpNegotiate, to_patch, http.Receive()); + } + + if (http) { + hr = vtable_patch::PatchInterfaceMethods(http, IHttpNegotiate_PatchInfo); + DLOG_IF(ERROR, FAILED(hr)) + << base::StringPrintf("HttpNegotiate patch failed 0x%08X", hr); + } else { + DLOG(WARNING) + << base::StringPrintf("IHttpNegotiate not supported 0x%08X", hr); + } + return hr; +} + +// static +HRESULT HttpNegotiatePatch::BeginningTransaction( + IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me, + LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers) { + DVLOG(1) << __FUNCTION__ << " " << url << " headers:\n" << headers; + + HRESULT hr = original(me, url, headers, reserved, additional_headers); + + if (FAILED(hr)) { + DLOG(WARNING) << __FUNCTION__ << " Delegate returned an error"; + return hr; + } + std::string updated(AppendCFUserAgentString(headers, *additional_headers)); + *additional_headers = reinterpret_cast<wchar_t*>(::CoTaskMemRealloc( + *additional_headers, (updated.length() + 1) * sizeof(wchar_t))); + lstrcpyW(*additional_headers, ASCIIToWide(updated).c_str()); + return S_OK; +} + |