diff options
author | ukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-10-27 09:40:11 +0000 |
---|---|---|
committer | ukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-10-27 09:40:11 +0000 |
commit | 5f7d8d759f8c2e5b1c88b3fa6398468868b592cb (patch) | |
tree | b89fe44530d7f293bacd16eb7799d20a25285264 /net/socket_stream | |
parent | 555d361a10df8ee24817128dfb867365d1f58f4e (diff) | |
download | chromium_src-5f7d8d759f8c2e5b1c88b3fa6398468868b592cb.zip chromium_src-5f7d8d759f8c2e5b1c88b3fa6398468868b592cb.tar.gz chromium_src-5f7d8d759f8c2e5b1c88b3fa6398468868b592cb.tar.bz2 |
Add proxy basic auth support in net/socket_stream.
BUG=none
TEST=net_unittests passes
Review URL: http://codereview.chromium.org/330016
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@30176 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/socket_stream')
-rw-r--r-- | net/socket_stream/socket_stream.cc | 153 | ||||
-rw-r--r-- | net/socket_stream/socket_stream.h | 37 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_unittest.cc | 210 |
3 files changed, 395 insertions, 5 deletions
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc index a90aca6..1dcf1f8 100644 --- a/net/socket_stream/socket_stream.cc +++ b/net/socket_stream/socket_stream.cc @@ -13,6 +13,7 @@ #include "base/logging.h" #include "base/message_loop.h" #include "base/string_util.h" +#include "net/base/auth.h" #include "net/base/host_resolver.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -149,6 +150,31 @@ void SocketStream::Close() { NewRunnableMethod(this, &SocketStream::DoLoop, OK)); } +void SocketStream::RestartWithAuth( + const std::wstring& username, const std::wstring& password) { + DCHECK(MessageLoop::current()) << + "The current MessageLoop must exist"; + DCHECK_EQ(MessageLoop::TYPE_IO, MessageLoop::current()->type()) << + "The current MessageLoop must be TYPE_IO"; + DCHECK(auth_handler_); + if (!socket_.get()) { + LOG(ERROR) << "Socket is closed before restarting with auth."; + return; + } + + if (auth_identity_.invalid) { + // Update the username/password. + auth_identity_.source = HttpAuth::IDENT_SRC_EXTERNAL; + auth_identity_.invalid = false; + auth_identity_.username = username; + auth_identity_.password = password; + } + + MessageLoop::current()->PostTask( + FROM_HERE, + NewRunnableMethod(this, &SocketStream::DoRestartWithAuth)); +} + void SocketStream::DetachDelegate() { if (!delegate_) return; @@ -161,6 +187,7 @@ void SocketStream::Finish() { "The current MessageLoop must exist"; DCHECK_EQ(MessageLoop::TYPE_IO, MessageLoop::current()->type()) << "The current MessageLoop must be TYPE_IO"; + DLOG(INFO) << "Finish"; Delegate* delegate = delegate_; delegate_ = NULL; if (delegate) { @@ -205,7 +232,7 @@ void SocketStream::DidReceiveData(int result) { void SocketStream::DidSendData(int result) { current_write_buf_ = NULL; - DCHECK(result > 0); + DCHECK_GT(result, 0); if (!delegate_) return; @@ -378,6 +405,7 @@ int SocketStream::DoResolveHost() { int SocketStream::DoResolveHostComplete(int result) { if (result == OK) next_state_ = STATE_TCP_CONNECT; + // TODO(ukai): if error occured, reconsider proxy after error. return result; } @@ -389,6 +417,7 @@ int SocketStream::DoTcpConnect() { } int SocketStream::DoTcpConnectComplete(int result) { + // TODO(ukai): if error occured, reconsider proxy after error. if (result != OK) return result; @@ -414,13 +443,43 @@ int SocketStream::DoWriteTunnelHeaders() { tunnel_request_headers_bytes_sent_ = 0; } if (tunnel_request_headers_->headers_.empty()) { + std::string authorization_headers; + + if (!auth_handler_.get()) { + // First attempt. Find auth from the proxy address. + HttpAuthCache::Entry* entry = auth_cache_.LookupByPath( + ProxyAuthOrigin(), std::string()); + if (entry && !entry->handler()->is_connection_based()) { + auth_identity_.source = HttpAuth::IDENT_SRC_PATH_LOOKUP; + auth_identity_.invalid = false; + auth_identity_.username = entry->username(); + auth_identity_.password = entry->password(); + auth_handler_ = entry->handler(); + } + } + + // Support basic authentication scheme only, because we don't have + // HttpRequestInfo. + // TODO(ukai): Add support other authentication scheme. + if (auth_handler_.get() && auth_handler_->scheme() == "basic") { + std::string credentials = auth_handler_->GenerateCredentials( + auth_identity_.username, + auth_identity_.password, + NULL, + &proxy_info_); + authorization_headers.append( + HttpAuth::GetAuthorizationHeaderName(HttpAuth::AUTH_PROXY) + + ": " + credentials + "\r\n"); + } + tunnel_request_headers_->headers_ = StringPrintf( "CONNECT %s HTTP/1.1\r\n" "Host: %s\r\n" "Proxy-Connection: keep-alive\r\n", GetHostAndPort(url_).c_str(), GetHostAndOptionalPort(url_).c_str()); - // TODO(ukai): set proxy auth if necessary. + if (!authorization_headers.empty()) + tunnel_request_headers_->headers_ += authorization_headers; tunnel_request_headers_->headers_ += "\r\n"; } tunnel_request_headers_->SetDataOffset(tunnel_request_headers_bytes_sent_); @@ -511,8 +570,21 @@ int SocketStream::DoReadTunnelHeadersComplete(int result) { } return OK; case 407: // Proxy Authentication Required. - // TODO(ukai): handle Proxy Authentication. - break; + result = HandleAuthChallenge(headers.get()); + if (result == ERR_PROXY_AUTH_REQUESTED && + auth_handler_.get() && delegate_) { + auth_info_ = new AuthChallengeInfo; + auth_info_->is_proxy = true; + auth_info_->host_and_port = + ASCIIToWide(proxy_info_.proxy_server().host_and_port()); + auth_info_->scheme = ASCIIToWide(auth_handler_->scheme()); + auth_info_->realm = ASCIIToWide(auth_handler_->realm()); + // Wait until RestartWithAuth or Close is called. + MessageLoop::current()->PostTask( + FROM_HERE, + NewRunnableMethod(this, &SocketStream::DoAuthRequired)); + return ERR_IO_PENDING; + } default: break; } @@ -615,6 +687,79 @@ int SocketStream::DoReadWrite(int result) { return ERR_IO_PENDING; } +GURL SocketStream::ProxyAuthOrigin() const { + return GURL("http://" + proxy_info_.proxy_server().host_and_port()); +} + +int SocketStream::HandleAuthChallenge(const HttpResponseHeaders* headers) { + GURL auth_origin(ProxyAuthOrigin()); + + LOG(INFO) << "The proxy " << auth_origin << " requested auth"; + + // The auth we tried just failed, hence it can't be valid. + // Remove it from the cache so it won't be used again. + if (auth_handler_.get() && !auth_identity_.invalid && + auth_handler_->IsFinalRound()) { + if (auth_identity_.source != HttpAuth::IDENT_SRC_PATH_LOOKUP) + auth_cache_.Remove(auth_origin, + auth_handler_->realm(), + auth_identity_.username, + auth_identity_.password); + auth_handler_ = NULL; + auth_identity_ = HttpAuth::Identity(); + } + + auth_identity_.invalid = true; + HttpAuth::ChooseBestChallenge(headers, HttpAuth::AUTH_PROXY, auth_origin, + &auth_handler_); + if (!auth_handler_) { + LOG(ERROR) << "Can't perform auth to the proxy " << auth_origin; + return ERR_TUNNEL_CONNECTION_FAILED; + } + if (auth_handler_->NeedsIdentity()) { + HttpAuthCache::Entry* entry = auth_cache_.LookupByRealm( + auth_origin, auth_handler_->realm()); + if (entry) { + if (entry->handler()->scheme() != "basic") { + // We only support basic authentication scheme now. + // TODO(ukai): Support other authentication scheme. + return ERR_TUNNEL_CONNECTION_FAILED; + } + auth_identity_.source = HttpAuth::IDENT_SRC_REALM_LOOKUP; + auth_identity_.invalid = false; + auth_identity_.username = entry->username(); + auth_identity_.password = entry->password(); + // Restart with auth info. + } + return ERR_PROXY_AUTH_REQUESTED; + } else { + auth_identity_.invalid = false; + } + return ERR_TUNNEL_CONNECTION_FAILED; +} + +void SocketStream::DoAuthRequired() { + if (delegate_ && auth_info_.get()) + delegate_->OnAuthRequired(this, auth_info_.get()); + else + DoLoop(net::ERR_UNEXPECTED); +} + +void SocketStream::DoRestartWithAuth() { + auth_cache_.Add(ProxyAuthOrigin(), auth_handler_, + auth_identity_.username, auth_identity_.password, + std::string()); + + tunnel_request_headers_ = NULL; + tunnel_request_headers_bytes_sent_ = 0; + tunnel_response_headers_ = NULL; + tunnel_response_headers_capacity_ = 0; + tunnel_response_headers_len_ = 0; + + next_state_ = STATE_TCP_CONNECT; + DoLoop(OK); +} + int SocketStream::HandleCertificateError(int result) { // TODO(ukai): handle cert error properly. switch (result) { diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h index 2a3dcb2..c1aaf6e 100644 --- a/net/socket_stream/socket_stream.h +++ b/net/socket_stream/socket_stream.h @@ -6,6 +6,7 @@ #define NET_SOCKET_STREAM_SOCKET_STREAM_H_ #include <deque> +#include <map> #include <string> #include <vector> @@ -16,17 +17,27 @@ #include "net/base/address_list.h" #include "net/base/completion_callback.h" #include "net/base/io_buffer.h" +#include "net/http/http_auth.h" +#include "net/http/http_auth_cache.h" +#include "net/http/http_auth_handler.h" #include "net/proxy/proxy_service.h" #include "net/socket/tcp_client_socket.h" #include "net/url_request/url_request_context.h" namespace net { +class AuthChallengeInfo; class ClientSocketFactory; class HostResolver; class SSLConfigService; class SingleRequestHostResolver; +// SocketStream is used to implement Web Sockets. +// It provides plain full-duplex stream with proxy and SSL support. +// For proxy authentication, only basic mechanisum is supported. It will try +// authentication identity for proxy URL first. If server requires proxy +// authentication, it will try authentication identity for realm that server +// requests. class SocketStream : public base::RefCountedThreadSafe<SocketStream> { public: // Derive from this class and add your own data members to associate extra @@ -59,6 +70,15 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { // Called when the socket stream has been closed. virtual void OnClose(SocketStream* socket) = 0; + + // Called when proxy authentication required. + // The delegate should call RestartWithAuth() if credential for |auth_info| + // is found in password database, or call Close() to close the connection. + virtual void OnAuthRequired(SocketStream* socket, + AuthChallengeInfo* auth_info) { + // By default, no credential is available and close the connection. + socket->Close(); + } }; SocketStream(const GURL& url, Delegate* delegate); @@ -91,6 +111,12 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { // Once the connection is closed, calls delegate's OnClose. void Close(); + // Restarts with authentication info. + // Should be used for response of OnAuthRequired. + void RestartWithAuth( + const std::wstring& username, + const std::wstring& password); + // Detach delegate. Call before delegate is deleted. // Once delegate is detached, close the socket stream and never call delegate // back. @@ -145,7 +171,6 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { STATE_SOCKS_CONNECT_COMPLETE, STATE_SSL_CONNECT, STATE_SSL_CONNECT_COMPLETE, - STATE_CONNECTION_ESTABLISHED, STATE_READ_WRITE }; @@ -189,6 +214,11 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { int DoSSLConnectComplete(int result); int DoReadWrite(int result); + GURL ProxyAuthOrigin() const; + int HandleAuthChallenge(const HttpResponseHeaders* headers); + void DoAuthRequired(); + void DoRestartWithAuth(); + int HandleCertificateError(int result); bool is_secure() const; @@ -212,6 +242,11 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> { ProxyService::PacRequest* pac_request_; ProxyInfo proxy_info_; + HttpAuthCache auth_cache_; + scoped_refptr<HttpAuthHandler> auth_handler_; + HttpAuth::Identity auth_identity_; + scoped_refptr<AuthChallengeInfo> auth_info_; + scoped_refptr<RequestHeaders> tunnel_request_headers_; size_t tunnel_request_headers_bytes_sent_; scoped_refptr<ResponseHeaders> tunnel_response_headers_; diff --git a/net/socket_stream/socket_stream_unittest.cc b/net/socket_stream/socket_stream_unittest.cc new file mode 100644 index 0000000..5cef70e --- /dev/null +++ b/net/socket_stream/socket_stream_unittest.cc @@ -0,0 +1,210 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> +#include <vector> + +#include "net/base/mock_host_resolver.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/socket_test_util.h" +#include "net/socket_stream/socket_stream.h" +#include "net/url_request/url_request_unittest.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +struct SocketStreamEvent { + enum EventType { + EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE, + EVENT_AUTH_REQUIRED, + }; + + SocketStreamEvent(EventType type, net::SocketStream* socket_stream, + int num, const std::string& str, + net::AuthChallengeInfo* auth_challenge_info) + : event_type(type), socket(socket_stream), number(num), data(str), + auth_info(auth_challenge_info) {} + + EventType event_type; + net::SocketStream* socket; + int number; + std::string data; + scoped_refptr<net::AuthChallengeInfo> auth_info; +}; + +class SocketStreamEventRecorder : public net::SocketStream::Delegate { + public: + explicit SocketStreamEventRecorder(net::CompletionCallback* callback) + : on_connected_(NULL), + on_sent_data_(NULL), + on_received_data_(NULL), + on_close_(NULL), + on_auth_required_(NULL), + callback_(callback) {} + virtual ~SocketStreamEventRecorder() { + delete on_connected_; + delete on_sent_data_; + delete on_received_data_; + delete on_close_; + delete on_auth_required_; + } + + void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) { + on_connected_ = callback; + } + void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) { + on_sent_data_ = callback; + } + void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) { + on_received_data_ = callback; + } + void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) { + on_close_ = callback; + } + void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) { + on_auth_required_ = callback; + } + + virtual void OnConnected(net::SocketStream* socket, + int num_pending_send_allowed) { + events_.push_back( + SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED, + socket, num_pending_send_allowed, std::string(), + NULL)); + if (on_connected_) + on_connected_->Run(&events_.back()); + } + virtual void OnSentData(net::SocketStream* socket, + int amount_sent) { + events_.push_back( + SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA, + socket, amount_sent, std::string(), NULL)); + if (on_sent_data_) + on_sent_data_->Run(&events_.back()); + } + virtual void OnReceivedData(net::SocketStream* socket, + const char* data, int len) { + events_.push_back( + SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA, + socket, len, std::string(data, len), NULL)); + if (on_received_data_) + on_received_data_->Run(&events_.back()); + } + virtual void OnClose(net::SocketStream* socket) { + events_.push_back( + SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE, + socket, 0, std::string(), NULL)); + if (on_close_) + on_close_->Run(&events_.back()); + if (callback_) + callback_->Run(net::OK); + } + virtual void OnAuthRequired(net::SocketStream* socket, + net::AuthChallengeInfo* auth_info) { + events_.push_back( + SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED, + socket, 0, std::string(), auth_info)); + if (on_auth_required_) + on_auth_required_->Run(&events_.back()); + } + + void DoClose(SocketStreamEvent* event) { + event->socket->Close(); + } + void DoRestartWithAuth(SocketStreamEvent* event) { + LOG(INFO) << "RestartWithAuth username=" << username_ + << " password=" << password_; + event->socket->RestartWithAuth(username_, password_); + } + void SetAuthInfo(const std::wstring& username, + const std::wstring& password) { + username_ = username; + password_ = password; + } + + const std::vector<SocketStreamEvent>& GetSeenEvents() const { + return events_; + } + + private: + std::vector<SocketStreamEvent> events_; + Callback1<SocketStreamEvent*>::Type* on_connected_; + Callback1<SocketStreamEvent*>::Type* on_sent_data_; + Callback1<SocketStreamEvent*>::Type* on_received_data_; + Callback1<SocketStreamEvent*>::Type* on_close_; + Callback1<SocketStreamEvent*>::Type* on_auth_required_; + net::CompletionCallback* callback_; + + std::wstring username_; + std::wstring password_; + + DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder); +}; + +namespace net { + +class SocketStreamTest : public PlatformTest { +}; + +TEST_F(SocketStreamTest, BasicAuthProxy) { + MockClientSocketFactory mock_socket_factory; + MockWrite data_writes1[] = { + MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" + "Host: example.com\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + MockRead data_reads1[] = { + MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), + MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), + MockRead("\r\n"), + }; + StaticMockSocket data1(data_reads1, data_writes1); + mock_socket_factory.AddMockSocket(&data1); + + MockWrite data_writes2[] = { + MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" + "Host: example.com\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), + }; + MockRead data_reads2[] = { + MockRead("HTTP/1.1 200 Connection Established\r\n"), + MockRead("Proxy-agent: Apache/2.2.8\r\n"), + MockRead("\r\n"), + }; + StaticMockSocket data2(data_reads2, data_writes2); + mock_socket_factory.AddMockSocket(&data2); + + TestCompletionCallback callback; + + scoped_ptr<SocketStreamEventRecorder> delegate( + new SocketStreamEventRecorder(&callback)); + delegate->SetOnConnected(NewCallback(delegate.get(), + &SocketStreamEventRecorder::DoClose)); + const std::wstring kUsername = L"foo"; + const std::wstring kPassword = L"bar"; + delegate->SetAuthInfo(kUsername, kPassword); + delegate->SetOnAuthRequired( + NewCallback(delegate.get(), + &SocketStreamEventRecorder::DoRestartWithAuth)); + + scoped_refptr<SocketStream> socket_stream = + new SocketStream(GURL("ws://example.com/demo"), delegate.get()); + + socket_stream->set_context(new TestURLRequestContext("myproxy:70")); + socket_stream->SetHostResolver(new MockHostResolver()); + socket_stream->SetClientSocketFactory(&mock_socket_factory); + + socket_stream->Connect(); + + callback.WaitForResult(); + + const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); + EXPECT_EQ(3U, events.size()); + + EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); + EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type); +} + +} // namespace net |