diff options
author | cbentzel@chromium.org <cbentzel@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-02-25 16:05:34 +0000 |
---|---|---|
committer | cbentzel@chromium.org <cbentzel@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2010-02-25 16:05:34 +0000 |
commit | 7c46fa57a345c80dbc812b450a9a2996ab6ace88 (patch) | |
tree | e6d3352695b9e7af5d58804bd065e3703b32c92c | |
parent | 09bde0c8d97f44a3a6476ac81d84697f05ecff6f (diff) | |
download | chromium_src-7c46fa57a345c80dbc812b450a9a2996ab6ace88.zip chromium_src-7c46fa57a345c80dbc812b450a9a2996ab6ace88.tar.gz chromium_src-7c46fa57a345c80dbc812b450a9a2996ab6ace88.tar.bz2 |
Added SSPILibrary interface so unit tests can mock SSPI calls.
BUG=None
TEST=net_unittests.exe --gtest_filter="*HttpAuthSSPI*"
Review URL: http://codereview.chromium.org/650164
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@40021 0039d316-1c4b-4281-b951-d872f2087c98
-rw-r--r-- | net/http/http_auth_handler_negotiate.h | 15 | ||||
-rw-r--r-- | net/http/http_auth_handler_negotiate_win.cc | 15 | ||||
-rw-r--r-- | net/http/http_auth_handler_ntlm.h | 14 | ||||
-rw-r--r-- | net/http/http_auth_handler_ntlm_win.cc | 16 | ||||
-rw-r--r-- | net/http/http_auth_sspi_win.cc | 126 | ||||
-rw-r--r-- | net/http/http_auth_sspi_win.h | 61 | ||||
-rw-r--r-- | net/http/http_auth_sspi_win_unittest.cc | 186 |
7 files changed, 386 insertions, 47 deletions
diff --git a/net/http/http_auth_handler_negotiate.h b/net/http/http_auth_handler_negotiate.h index 73b1240..4577d9c 100644 --- a/net/http/http_auth_handler_negotiate.h +++ b/net/http/http_auth_handler_negotiate.h @@ -34,16 +34,29 @@ class HttpAuthHandlerNegotiate : public HttpAuthHandler { HttpAuth::Target target, const GURL& origin, scoped_refptr<HttpAuthHandler>* handler); + +#if defined(OS_WIN) + // Set the SSPILibrary to use. Typically the only callers which need to + // use this are unit tests which pass in a mocked-out version of the + // SSPI library. + // The caller is responsible for managing the lifetime of |*sspi_library|, + // and the lifetime must exceed that of this Factory object and all + // HttpAuthHandler's that this Factory object creates. + void set_sspi_library(SSPILibrary* sspi_library) { + sspi_library_ = sspi_library; + } +#endif // defined(OS_WIN) private: #if defined(OS_WIN) ULONG max_token_length_; bool first_creation_; bool is_unsupported_; + SSPILibrary* sspi_library_; #endif // defined(OS_WIN) }; #if defined(OS_WIN) - explicit HttpAuthHandlerNegotiate(ULONG max_token_length); + HttpAuthHandlerNegotiate(SSPILibrary* sspi_library, ULONG max_token_length); #else HttpAuthHandlerNegotiate(); #endif diff --git a/net/http/http_auth_handler_negotiate_win.cc b/net/http/http_auth_handler_negotiate_win.cc index bbd62a4..ffa48ef9 100644 --- a/net/http/http_auth_handler_negotiate_win.cc +++ b/net/http/http_auth_handler_negotiate_win.cc @@ -8,8 +8,9 @@ namespace net { -HttpAuthHandlerNegotiate::HttpAuthHandlerNegotiate(ULONG max_token_length) - : auth_sspi_("Negotiate", NEGOSSP_NAME, max_token_length) { +HttpAuthHandlerNegotiate::HttpAuthHandlerNegotiate(SSPILibrary* library, + ULONG max_token_length) + : auth_sspi_(library, "Negotiate", NEGOSSP_NAME, max_token_length) { } HttpAuthHandlerNegotiate::~HttpAuthHandlerNegotiate() { @@ -77,7 +78,8 @@ int HttpAuthHandlerNegotiate::GenerateDefaultAuthToken( HttpAuthHandlerNegotiate::Factory::Factory() : max_token_length_(0), first_creation_(true), - is_unsupported_(false) { + is_unsupported_(false), + sspi_library_(SSPILibrary::GetDefault()) { } HttpAuthHandlerNegotiate::Factory::~Factory() { @@ -90,19 +92,18 @@ int HttpAuthHandlerNegotiate::Factory::CreateAuthHandler( scoped_refptr<HttpAuthHandler>* handler) { if (is_unsupported_) return ERR_UNSUPPORTED_AUTH_SCHEME; - if (max_token_length_ == 0) { - int rv = DetermineMaxTokenLength(NEGOSSP_NAME, &max_token_length_); + int rv = DetermineMaxTokenLength(sspi_library_, NEGOSSP_NAME, + &max_token_length_); if (rv == ERR_UNSUPPORTED_AUTH_SCHEME) is_unsupported_ = true; if (rv != OK) return rv; } - // TODO(cbentzel): Move towards model of parsing in the factory // method and only constructing when valid. scoped_refptr<HttpAuthHandler> tmp_handler( - new HttpAuthHandlerNegotiate(max_token_length_)); + new HttpAuthHandlerNegotiate(sspi_library_, max_token_length_)); if (!tmp_handler->InitFromChallenge(challenge, target, origin)) return ERR_INVALID_RESPONSE; handler->swap(tmp_handler); diff --git a/net/http/http_auth_handler_ntlm.h b/net/http/http_auth_handler_ntlm.h index 9e86e1e..aebb5c4 100644 --- a/net/http/http_auth_handler_ntlm.h +++ b/net/http/http_auth_handler_ntlm.h @@ -43,11 +43,23 @@ class HttpAuthHandlerNTLM : public HttpAuthHandler { HttpAuth::Target target, const GURL& origin, scoped_refptr<HttpAuthHandler>* handler); +#if defined(NTLM_SSPI) + // Set the SSPILibrary to use. Typically the only callers which need to + // use this are unit tests which pass in a mocked-out version of the + // SSPI library. + // The caller is responsible for managing the lifetime of |*sspi_library|, + // and the lifetime must exceed that of this Factory object and all + // HttpAuthHandler's that this Factory object creates. + void set_sspi_library(SSPILibrary* sspi_library) { + sspi_library_ = sspi_library; + } +#endif // defined(NTLM_SSPI) private: #if defined(NTLM_SSPI) ULONG max_token_length_; bool first_creation_; bool is_unsupported_; + SSPILibrary* sspi_library_; #endif // defined(NTLM_SSPI) }; @@ -84,7 +96,7 @@ class HttpAuthHandlerNTLM : public HttpAuthHandler { HttpAuthHandlerNTLM(); #endif #if defined(NTLM_SSPI) - HttpAuthHandlerNTLM(ULONG max_token_length); + HttpAuthHandlerNTLM(SSPILibrary* sspi_library, ULONG max_token_length); #endif virtual bool NeedsIdentity(); diff --git a/net/http/http_auth_handler_ntlm_win.cc b/net/http/http_auth_handler_ntlm_win.cc index cf3b448..989f1db 100644 --- a/net/http/http_auth_handler_ntlm_win.cc +++ b/net/http/http_auth_handler_ntlm_win.cc @@ -19,8 +19,9 @@ namespace net { -HttpAuthHandlerNTLM::HttpAuthHandlerNTLM(ULONG max_token_length) : - auth_sspi_("NTLM", NTLMSP_NAME, max_token_length) { +HttpAuthHandlerNTLM::HttpAuthHandlerNTLM(SSPILibrary* sspi_library, + ULONG max_token_length) : + auth_sspi_(sspi_library, "NTLM", NTLMSP_NAME, max_token_length) { } HttpAuthHandlerNTLM::~HttpAuthHandlerNTLM() { @@ -64,7 +65,8 @@ int HttpAuthHandlerNTLM::GenerateDefaultAuthToken( HttpAuthHandlerNTLM::Factory::Factory() : max_token_length_(0), first_creation_(true), - is_unsupported_(false) { + is_unsupported_(false), + sspi_library_(SSPILibrary::GetDefault()) { } HttpAuthHandlerNTLM::Factory::~Factory() { @@ -77,19 +79,18 @@ int HttpAuthHandlerNTLM::Factory::CreateAuthHandler( scoped_refptr<HttpAuthHandler>* handler) { if (is_unsupported_) return ERR_UNSUPPORTED_AUTH_SCHEME; - if (max_token_length_ == 0) { - int rv = DetermineMaxTokenLength(NTLMSP_NAME, &max_token_length_); + int rv = DetermineMaxTokenLength(sspi_library_, NTLMSP_NAME, + &max_token_length_); if (rv == ERR_UNSUPPORTED_AUTH_SCHEME) is_unsupported_ = true; if (rv != OK) return rv; } - // TODO(cbentzel): Move towards model of parsing in the factory // method and only constructing when valid. scoped_refptr<HttpAuthHandler> tmp_handler( - new HttpAuthHandlerNTLM(max_token_length_)); + new HttpAuthHandlerNTLM(sspi_library_, max_token_length_)); if (!tmp_handler->InitFromChallenge(challenge, target, origin)) return ERR_INVALID_RESPONSE; handler->swap(tmp_handler); @@ -97,4 +98,3 @@ int HttpAuthHandlerNTLM::Factory::CreateAuthHandler( } } // namespace net - diff --git a/net/http/http_auth_sspi_win.cc b/net/http/http_auth_sspi_win.cc index 39d87af..5349e76 100644 --- a/net/http/http_auth_sspi_win.cc +++ b/net/http/http_auth_sspi_win.cc @@ -9,12 +9,14 @@ #include "base/base64.h" #include "base/logging.h" +#include "base/singleton.h" #include "base/string_util.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/http/http_auth.h" namespace net { + namespace { int MapAcquireCredentialsStatusToError(SECURITY_STATUS status, @@ -40,7 +42,8 @@ int MapAcquireCredentialsStatusToError(SECURITY_STATUS status, } } -int AcquireCredentials(const SEC_WCHAR* package, +int AcquireCredentials(SSPILibrary* library, + const SEC_WCHAR* package, const std::wstring& domain, const std::wstring& user, const std::wstring& password, @@ -60,7 +63,7 @@ int AcquireCredentials(const SEC_WCHAR* package, TimeStamp expiry; // Pass the username/password to get the credentials handle. - SECURITY_STATUS status = AcquireCredentialsHandle( + SECURITY_STATUS status = library->AcquireCredentialsHandle( NULL, // pszPrincipal const_cast<SEC_WCHAR*>(package), // pszPackage SECPKG_CRED_OUTBOUND, // fCredentialUse @@ -74,14 +77,15 @@ int AcquireCredentials(const SEC_WCHAR* package, return MapAcquireCredentialsStatusToError(status, package); } -int AcquireDefaultCredentials(const SEC_WCHAR* package, CredHandle* cred) { +int AcquireDefaultCredentials(SSPILibrary* library, const SEC_WCHAR* package, + CredHandle* cred) { TimeStamp expiry; // Pass the username/password to get the credentials handle. // Note: Since the 5th argument is NULL, it uses the default // cached credentials for the logged in user, which can be used // for a single sign-on. - SECURITY_STATUS status = AcquireCredentialsHandle( + SECURITY_STATUS status = library->AcquireCredentialsHandle( NULL, // pszPrincipal const_cast<SEC_WCHAR*>(package), // pszPackage SECPKG_CRED_OUTBOUND, // fCredentialUse @@ -97,12 +101,15 @@ int AcquireDefaultCredentials(const SEC_WCHAR* package, CredHandle* cred) { } // anonymous namespace -HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme, +HttpAuthSSPI::HttpAuthSSPI(SSPILibrary* library, + const std::string& scheme, SEC_WCHAR* security_package, ULONG max_token_length) - : scheme_(scheme), + : library_(library), + scheme_(scheme), security_package_(security_package), max_token_length_(max_token_length) { + DCHECK(library_); SecInvalidateHandle(&cred_); SecInvalidateHandle(&ctxt_); } @@ -110,7 +117,7 @@ HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme, HttpAuthSSPI::~HttpAuthSSPI() { ResetSecurityContext(); if (SecIsValidHandle(&cred_)) { - FreeCredentialsHandle(&cred_); + library_->FreeCredentialsHandle(&cred_); SecInvalidateHandle(&cred_); } } @@ -125,7 +132,7 @@ bool HttpAuthSSPI::IsFinalRound() const { void HttpAuthSSPI::ResetSecurityContext() { if (SecIsValidHandle(&ctxt_)) { - DeleteSecurityContext(&ctxt_); + library_->DeleteSecurityContext(&ctxt_); SecInvalidateHandle(&ctxt_); } } @@ -202,11 +209,12 @@ int HttpAuthSSPI::OnFirstRound(const std::wstring* username, std::wstring domain; std::wstring user; SplitDomainAndUser(*username, &domain, &user); - rv = AcquireCredentials(security_package_, domain, user, *password, &cred_); + rv = AcquireCredentials(library_, security_package_, domain, + user, *password, &cred_); if (rv != OK) return rv; } else { - rv = AcquireDefaultCredentials(security_package_, &cred_); + rv = AcquireDefaultCredentials(library_, security_package_, &cred_); if (rv != OK) return rv; } @@ -268,18 +276,19 @@ int HttpAuthSSPI::GetNextSecurityToken( wchar_t* target_name = const_cast<wchar_t*>(target.c_str()); // This returns a token that is passed to the remote server. - status = InitializeSecurityContext(&cred_, // phCredential - ctxt_ptr, // phContext - target_name, // pszTargetName - 0, // fContextReq - 0, // Reserved1 (must be 0) - SECURITY_NATIVE_DREP, // TargetDataRep - in_buffer_desc_ptr, // pInput - 0, // Reserved2 (must be 0) - &ctxt_, // phNewContext - &out_buffer_desc, // pOutput - &ctxt_attr, // pfContextAttr - &expiry); // ptsExpiry + status = library_->InitializeSecurityContext( + &cred_, // phCredential + ctxt_ptr, // phContext + target_name, // pszTargetName + 0, // fContextReq + 0, // Reserved1 (must be 0) + SECURITY_NATIVE_DREP, // TargetDataRep + in_buffer_desc_ptr, // pInput + 0, // Reserved2 (must be 0) + &ctxt_, // phNewContext + &out_buffer_desc, // pOutput + &ctxt_attr, // pfContextAttr + &expiry); // ptsExpiry // On success, the function returns SEC_I_CONTINUE_NEEDED on the first call // and SEC_E_OK on the second call. On failure, the function returns an // error code. @@ -314,10 +323,13 @@ void SplitDomainAndUser(const std::wstring& combined, } } -int DetermineMaxTokenLength(const std::wstring& package, +int DetermineMaxTokenLength(SSPILibrary* library, + const std::wstring& package, ULONG* max_token_length) { + DCHECK(library); + DCHECK(max_token_length); PSecPkgInfo pkg_info = NULL; - SECURITY_STATUS status = QuerySecurityPackageInfo( + SECURITY_STATUS status = library->QuerySecurityPackageInfo( const_cast<wchar_t *>(package.c_str()), &pkg_info); if (status != SEC_E_OK) { // The documentation at @@ -333,7 +345,7 @@ int DetermineMaxTokenLength(const std::wstring& package, return ERR_UNEXPECTED; } int token_length = pkg_info->cbMaxToken; - status = FreeContextBuffer(pkg_info); + status = library->FreeContextBuffer(pkg_info); if (status != SEC_E_OK) { // The documentation at // http://msdn.microsoft.com/en-us/library/aa375416(VS.85).aspx @@ -348,4 +360,68 @@ int DetermineMaxTokenLength(const std::wstring& package, return OK; } +class SSPILibraryDefault : public SSPILibrary { + public: + SSPILibraryDefault() {} + virtual ~SSPILibraryDefault() {} + + virtual SECURITY_STATUS AcquireCredentialsHandle(LPWSTR pszPrincipal, + LPWSTR pszPackage, + unsigned long fCredentialUse, + void* pvLogonId, + void* pvAuthData, + SEC_GET_KEY_FN pGetKeyFn, + void* pvGetKeyArgument, + PCredHandle phCredential, + PTimeStamp ptsExpiry) { + return ::AcquireCredentialsHandle(pszPrincipal, pszPackage, fCredentialUse, + pvLogonId, pvAuthData, pGetKeyFn, + pvGetKeyArgument, phCredential, + ptsExpiry); + } + + virtual SECURITY_STATUS InitializeSecurityContext(PCredHandle phCredential, + PCtxtHandle phContext, + SEC_WCHAR* pszTargetName, + unsigned long fContextReq, + unsigned long Reserved1, + unsigned long TargetDataRep, + PSecBufferDesc pInput, + unsigned long Reserved2, + PCtxtHandle phNewContext, + PSecBufferDesc pOutput, + unsigned long* contextAttr, + PTimeStamp ptsExpiry) { + return ::InitializeSecurityContext(phCredential, phContext, pszTargetName, + fContextReq, Reserved1, TargetDataRep, + pInput, Reserved2, phNewContext, pOutput, + contextAttr, ptsExpiry); + } + + virtual SECURITY_STATUS QuerySecurityPackageInfo(LPWSTR pszPackageName, + PSecPkgInfoW *pkgInfo) { + return ::QuerySecurityPackageInfo(pszPackageName, pkgInfo); + } + + virtual SECURITY_STATUS FreeCredentialsHandle(PCredHandle phCredential) { + return ::FreeCredentialsHandle(phCredential); + } + + virtual SECURITY_STATUS DeleteSecurityContext(PCtxtHandle phContext) { + return ::DeleteSecurityContext(phContext); + } + + virtual SECURITY_STATUS FreeContextBuffer(PVOID pvContextBuffer) { + return ::FreeContextBuffer(pvContextBuffer); + } + + private: + friend struct DefaultSingletonTraits<SSPILibraryDefault>; +}; + +// static +SSPILibrary* SSPILibrary::GetDefault() { + return Singleton<SSPILibraryDefault>::get(); +} + } // namespace net diff --git a/net/http/http_auth_sspi_win.h b/net/http/http_auth_sspi_win.h index 186c65f..c925920 100644 --- a/net/http/http_auth_sspi_win.h +++ b/net/http/http_auth_sspi_win.h @@ -25,9 +25,58 @@ namespace net { class HttpRequestInfo; class ProxyInfo; +// SSPILibrary is introduced so unit tests can mock the calls to Windows' SSPI +// implementation. The default implementation simply passes the arguments on to +// the SSPI implementation provided by Secur32.dll. +// NOTE(cbentzel): I considered replacing the Secur32.dll with a mock DLL, but +// decided that it wasn't worth the effort as this is unlikely to be performance +// sensitive code. +class SSPILibrary { + public: + virtual ~SSPILibrary() {}; + + virtual SECURITY_STATUS AcquireCredentialsHandle(LPWSTR pszPrincipal, + LPWSTR pszPackage, + unsigned long fCredentialUse, + void* pvLogonId, + void* pvAuthData, + SEC_GET_KEY_FN pGetKeyFn, + void* pvGetKeyArgument, + PCredHandle phCredential, + PTimeStamp ptsExpiry) = 0; + + virtual SECURITY_STATUS InitializeSecurityContext(PCredHandle phCredential, + PCtxtHandle phContext, + SEC_WCHAR* pszTargetName, + unsigned long fContextReq, + unsigned long Reserved1, + unsigned long TargetDataRep, + PSecBufferDesc pInput, + unsigned long Reserved2, + PCtxtHandle phNewContext, + PSecBufferDesc pOutput, + unsigned long* contextAttr, + PTimeStamp ptsExpiry) = 0; + + virtual SECURITY_STATUS QuerySecurityPackageInfo(LPWSTR pszPackageName, + PSecPkgInfoW *pkgInfo) = 0; + + virtual SECURITY_STATUS FreeCredentialsHandle(PCredHandle phCredential) = 0; + + virtual SECURITY_STATUS DeleteSecurityContext(PCtxtHandle phContext) = 0; + + virtual SECURITY_STATUS FreeContextBuffer(PVOID pvContextBuffer) = 0; + + // Get the default SSPILibrary instance, which simply acts as a passthrough + // to the Windows SSPI implementation. The object returned is a singleton + // instance, and the caller should not delete it. + static SSPILibrary* GetDefault(); +}; + class HttpAuthSSPI { public: - HttpAuthSSPI(const std::string& scheme, + HttpAuthSSPI(SSPILibrary* sspi_library, + const std::string& scheme, SEC_WCHAR* security_package, ULONG max_token_length); ~HttpAuthSSPI(); @@ -62,6 +111,8 @@ class HttpAuthSSPI { int* out_token_len); void ResetSecurityContext(); + + SSPILibrary* library_; std::string scheme_; SEC_WCHAR* security_package_; std::string decoded_server_auth_token_; @@ -82,11 +133,11 @@ void SplitDomainAndUser(const std::wstring& combined, // Determines the maximum token length in bytes for a particular SSPI package. // +// |library| and |max_token_length| must be non-NULL pointers to valid objects. +// // If the return value is OK, |*max_token_length| contains the maximum token // length in bytes. // -// If the return value is ERR_INVALID_ARGUMENT, |max_token_length| is NULL. -// // If the return value is ERR_UNSUPPORTED_AUTH_SCHEME, |package| is not an // known SSPI authentication scheme on this system. |*max_token_length| is not // changed. @@ -94,9 +145,11 @@ void SplitDomainAndUser(const std::wstring& combined, // If the return value is ERR_UNEXPECTED, there was an unanticipated problem // in the underlying SSPI call. The details are logged, and |*max_token_length| // is not changed. -int DetermineMaxTokenLength(const std::wstring& package, +int DetermineMaxTokenLength(SSPILibrary* library, + const std::wstring& package, ULONG* max_token_length); } // namespace net + #endif // NET_HTTP_HTTP_AUTH_SSPI_WIN_H_ diff --git a/net/http/http_auth_sspi_win_unittest.cc b/net/http/http_auth_sspi_win_unittest.cc index b421ca9..7690664 100644 --- a/net/http/http_auth_sspi_win_unittest.cc +++ b/net/http/http_auth_sspi_win_unittest.cc @@ -2,12 +2,18 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <list> +#include <set> + #include "base/basictypes.h" +#include "net/base/net_errors.h" #include "net/http/http_auth_sspi_win.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { +namespace { + void MatchDomainUserAfterSplit(const std::wstring& combined, const std::wstring& expected_domain, const std::wstring& expected_user) { @@ -18,9 +24,187 @@ void MatchDomainUserAfterSplit(const std::wstring& combined, EXPECT_EQ(expected_user, actual_user); } -TEST(HttpAuthHandlerSspiWinTest, SplitUserAndDomain) { +class MockSSPILibrary : public SSPILibrary { + public: + MockSSPILibrary() {} + virtual ~MockSSPILibrary() {} + + virtual SECURITY_STATUS AcquireCredentialsHandle(LPWSTR pszPrincipal, + LPWSTR pszPackage, + unsigned long fCredentialUse, + void* pvLogonId, + void* pvAuthData, + SEC_GET_KEY_FN pGetKeyFn, + void* pvGetKeyArgument, + PCredHandle phCredential, + PTimeStamp ptsExpiry) { + ADD_FAILURE(); + return ERROR_CALL_NOT_IMPLEMENTED; + } + + virtual SECURITY_STATUS InitializeSecurityContext(PCredHandle phCredential, + PCtxtHandle phContext, + SEC_WCHAR* pszTargetName, + unsigned long fContextReq, + unsigned long Reserved1, + unsigned long TargetDataRep, + PSecBufferDesc pInput, + unsigned long Reserved2, + PCtxtHandle phNewContext, + PSecBufferDesc pOutput, + unsigned long* contextAttr, + PTimeStamp ptsExpiry) { + ADD_FAILURE(); + return ERROR_CALL_NOT_IMPLEMENTED; + } + + virtual SECURITY_STATUS QuerySecurityPackageInfo(LPWSTR pszPackageName, + PSecPkgInfoW *pkgInfo) { + ADD_FAILURE(); + return ERROR_CALL_NOT_IMPLEMENTED; + } + + virtual SECURITY_STATUS FreeCredentialsHandle(PCredHandle phCredential) { + ADD_FAILURE(); + return ERROR_CALL_NOT_IMPLEMENTED; + } + + virtual SECURITY_STATUS DeleteSecurityContext(PCtxtHandle phContext) { + ADD_FAILURE(); + return ERROR_CALL_NOT_IMPLEMENTED; + } + + virtual SECURITY_STATUS FreeContextBuffer(PVOID pvContextBuffer) { + ADD_FAILURE(); + return ERROR_CALL_NOT_IMPLEMENTED; + } +}; + +class ExpectedPackageQuerySSPILibrary : public MockSSPILibrary { + public: + ExpectedPackageQuerySSPILibrary() {} + virtual ~ExpectedPackageQuerySSPILibrary() { + EXPECT_TRUE(expected_package_queries_.empty()); + EXPECT_TRUE(expected_freed_packages_.empty()); + } + + // Establishes an expectation for a |QuerySecurityPackageInfo()| call. + // + // Each expectation established by |ExpectSecurityQueryPackageInfo()| must be + // matched by a call to |QuerySecurityPackageInfo()| during the lifetime of + // the MockSSPILibrary. The |expected_package| argument must equal the + // |*pszPackageName| argument to |QuerySecurityPackageInfo()| for there to be + // a match. The expectations also establish an explicit ordering. + // + // For example, this sequence will be successful. + // ExpectedPackageQuerySSPILibrary lib; + // lib.ExpectQuerySecurityPackageInfo(L"NTLM", ...) + // lib.ExpectQuerySecurityPackageInfo(L"Negotiate", ...) + // lib.QuerySecurityPackageInfo(L"NTLM", ...) + // lib.QuerySecurityPackageInfo(L"Negotiate", ...) + // + // This sequence will fail since the queries do not occur in the order + // established by the expectations. + // ExpectedPackageQuerySSPILibrary lib; + // lib.ExpectQuerySecurityPackageInfo(L"NTLM", ...) + // lib.ExpectQuerySecurityPackageInfo(L"Negotiate", ...) + // lib.QuerySecurityPackageInfo(L"Negotiate", ...) + // lib.QuerySecurityPackageInfo(L"NTLM", ...) + // + // This sequence will fail because there were not enough queries. + // ExpectedPackageQuerySSPILibrary lib; + // lib.ExpectQuerySecurityPackageInfo(L"NTLM", ...) + // lib.ExpectQuerySecurityPackageInfo(L"Negotiate", ...) + // lib.QuerySecurityPackageInfo(L"NTLM", ...) + // + // |response_code| is used as the return value for + // |QuerySecurityPackageInfo()|. If |response_code| is SEC_E_OK, + // an expectation is also set for a call to |FreeContextBuffer()| after + // the matching |QuerySecurityPackageInfo()| is called. + // + // |package_info| is assigned to |*pkgInfo| in |QuerySecurityPackageInfo|. + // The lifetime of |*package_info| should last at least until the matching + // |QuerySecurityPackageInfo()| is called. + void ExpectQuerySecurityPackageInfo(const std::wstring& expected_package, + SECURITY_STATUS response_code, + PSecPkgInfoW package_info) { + PackageQuery package_query = {expected_package, response_code, + package_info}; + expected_package_queries_.push_back(package_query); + } + + // Queries security package information. This must be an expected call, + // see |ExpectQuerySecurityPackageInfo()| for more information. + virtual SECURITY_STATUS QuerySecurityPackageInfo(LPWSTR pszPackageName, + PSecPkgInfoW* pkgInfo) { + EXPECT_TRUE(!expected_package_queries_.empty()); + PackageQuery package_query = expected_package_queries_.front(); + expected_package_queries_.pop_front(); + std::wstring actual_package(pszPackageName); + EXPECT_EQ(package_query.expected_package, actual_package); + *pkgInfo = package_query.package_info; + if (package_query.response_code == SEC_E_OK) + expected_freed_packages_.insert(package_query.package_info); + return package_query.response_code; + } + + // Frees the context buffer. This should be called after a successful call + // of |QuerySecurityPackageInfo()|. + virtual SECURITY_STATUS FreeContextBuffer(PVOID pvContextBuffer) { + PSecPkgInfoW package_info = static_cast<PSecPkgInfoW>(pvContextBuffer); + std::set<PSecPkgInfoW>::iterator it = expected_freed_packages_.find( + package_info); + EXPECT_TRUE(it != expected_freed_packages_.end()); + expected_freed_packages_.erase(it); + return SEC_E_OK; + } + + private: + struct PackageQuery { + std::wstring expected_package; + SECURITY_STATUS response_code; + PSecPkgInfoW package_info; + }; + + // expected_package_queries contains an ordered list of expected + // |QuerySecurityPackageInfo()| calls and the return values for those + // calls. + std::list<PackageQuery> expected_package_queries_; + + // Set of packages which should be freed. + std::set<PSecPkgInfoW> expected_freed_packages_; +}; + +} // namespace + +TEST(HttpAuthSSPITest, SplitUserAndDomain) { MatchDomainUserAfterSplit(L"foobar", L"", L"foobar"); MatchDomainUserAfterSplit(L"FOO\\bar", L"FOO", L"bar"); } +TEST(HttpAuthSSPITest, DetermineMaxTokenLength_Normal) { + SecPkgInfoW package_info; + memset(&package_info, 0x0, sizeof(package_info)); + package_info.cbMaxToken = 1337; + + ExpectedPackageQuerySSPILibrary mock_library; + mock_library.ExpectQuerySecurityPackageInfo(L"NTLM", SEC_E_OK, &package_info); + ULONG max_token_length = 100; + int rv = DetermineMaxTokenLength(&mock_library, L"NTLM", &max_token_length); + EXPECT_EQ(OK, rv); + EXPECT_EQ(1337, max_token_length); +} + +TEST(HttpAuthSSPITest, DetermineMaxTokenLength_InvalidPackage) { + ExpectedPackageQuerySSPILibrary mock_library; + mock_library.ExpectQuerySecurityPackageInfo(L"Foo", SEC_E_SECPKG_NOT_FOUND, + NULL); + ULONG max_token_length = 100; + int rv = DetermineMaxTokenLength(&mock_library, L"Foo", &max_token_length); + EXPECT_EQ(ERR_UNSUPPORTED_AUTH_SCHEME, rv); + // |DetermineMaxTokenLength()| interface states that |max_token_length| should + // not change on failure. + EXPECT_EQ(100, max_token_length); +} + } // namespace net |