summaryrefslogtreecommitdiffstats
path: root/net/socket_stream
diff options
context:
space:
mode:
authorukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2009-10-27 09:40:11 +0000
committerukai@chromium.org <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>2009-10-27 09:40:11 +0000
commit5f7d8d759f8c2e5b1c88b3fa6398468868b592cb (patch)
treeb89fe44530d7f293bacd16eb7799d20a25285264 /net/socket_stream
parent555d361a10df8ee24817128dfb867365d1f58f4e (diff)
downloadchromium_src-5f7d8d759f8c2e5b1c88b3fa6398468868b592cb.zip
chromium_src-5f7d8d759f8c2e5b1c88b3fa6398468868b592cb.tar.gz
chromium_src-5f7d8d759f8c2e5b1c88b3fa6398468868b592cb.tar.bz2
Add proxy basic auth support in net/socket_stream.
BUG=none TEST=net_unittests passes Review URL: http://codereview.chromium.org/330016 git-svn-id: svn://svn.chromium.org/chrome/trunk/src@30176 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'net/socket_stream')
-rw-r--r--net/socket_stream/socket_stream.cc153
-rw-r--r--net/socket_stream/socket_stream.h37
-rw-r--r--net/socket_stream/socket_stream_unittest.cc210
3 files changed, 395 insertions, 5 deletions
diff --git a/net/socket_stream/socket_stream.cc b/net/socket_stream/socket_stream.cc
index a90aca6..1dcf1f8 100644
--- a/net/socket_stream/socket_stream.cc
+++ b/net/socket_stream/socket_stream.cc
@@ -13,6 +13,7 @@
#include "base/logging.h"
#include "base/message_loop.h"
#include "base/string_util.h"
+#include "net/base/auth.h"
#include "net/base/host_resolver.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
@@ -149,6 +150,31 @@ void SocketStream::Close() {
NewRunnableMethod(this, &SocketStream::DoLoop, OK));
}
+void SocketStream::RestartWithAuth(
+ const std::wstring& username, const std::wstring& password) {
+ DCHECK(MessageLoop::current()) <<
+ "The current MessageLoop must exist";
+ DCHECK_EQ(MessageLoop::TYPE_IO, MessageLoop::current()->type()) <<
+ "The current MessageLoop must be TYPE_IO";
+ DCHECK(auth_handler_);
+ if (!socket_.get()) {
+ LOG(ERROR) << "Socket is closed before restarting with auth.";
+ return;
+ }
+
+ if (auth_identity_.invalid) {
+ // Update the username/password.
+ auth_identity_.source = HttpAuth::IDENT_SRC_EXTERNAL;
+ auth_identity_.invalid = false;
+ auth_identity_.username = username;
+ auth_identity_.password = password;
+ }
+
+ MessageLoop::current()->PostTask(
+ FROM_HERE,
+ NewRunnableMethod(this, &SocketStream::DoRestartWithAuth));
+}
+
void SocketStream::DetachDelegate() {
if (!delegate_)
return;
@@ -161,6 +187,7 @@ void SocketStream::Finish() {
"The current MessageLoop must exist";
DCHECK_EQ(MessageLoop::TYPE_IO, MessageLoop::current()->type()) <<
"The current MessageLoop must be TYPE_IO";
+ DLOG(INFO) << "Finish";
Delegate* delegate = delegate_;
delegate_ = NULL;
if (delegate) {
@@ -205,7 +232,7 @@ void SocketStream::DidReceiveData(int result) {
void SocketStream::DidSendData(int result) {
current_write_buf_ = NULL;
- DCHECK(result > 0);
+ DCHECK_GT(result, 0);
if (!delegate_)
return;
@@ -378,6 +405,7 @@ int SocketStream::DoResolveHost() {
int SocketStream::DoResolveHostComplete(int result) {
if (result == OK)
next_state_ = STATE_TCP_CONNECT;
+ // TODO(ukai): if error occured, reconsider proxy after error.
return result;
}
@@ -389,6 +417,7 @@ int SocketStream::DoTcpConnect() {
}
int SocketStream::DoTcpConnectComplete(int result) {
+ // TODO(ukai): if error occured, reconsider proxy after error.
if (result != OK)
return result;
@@ -414,13 +443,43 @@ int SocketStream::DoWriteTunnelHeaders() {
tunnel_request_headers_bytes_sent_ = 0;
}
if (tunnel_request_headers_->headers_.empty()) {
+ std::string authorization_headers;
+
+ if (!auth_handler_.get()) {
+ // First attempt. Find auth from the proxy address.
+ HttpAuthCache::Entry* entry = auth_cache_.LookupByPath(
+ ProxyAuthOrigin(), std::string());
+ if (entry && !entry->handler()->is_connection_based()) {
+ auth_identity_.source = HttpAuth::IDENT_SRC_PATH_LOOKUP;
+ auth_identity_.invalid = false;
+ auth_identity_.username = entry->username();
+ auth_identity_.password = entry->password();
+ auth_handler_ = entry->handler();
+ }
+ }
+
+ // Support basic authentication scheme only, because we don't have
+ // HttpRequestInfo.
+ // TODO(ukai): Add support other authentication scheme.
+ if (auth_handler_.get() && auth_handler_->scheme() == "basic") {
+ std::string credentials = auth_handler_->GenerateCredentials(
+ auth_identity_.username,
+ auth_identity_.password,
+ NULL,
+ &proxy_info_);
+ authorization_headers.append(
+ HttpAuth::GetAuthorizationHeaderName(HttpAuth::AUTH_PROXY) +
+ ": " + credentials + "\r\n");
+ }
+
tunnel_request_headers_->headers_ = StringPrintf(
"CONNECT %s HTTP/1.1\r\n"
"Host: %s\r\n"
"Proxy-Connection: keep-alive\r\n",
GetHostAndPort(url_).c_str(),
GetHostAndOptionalPort(url_).c_str());
- // TODO(ukai): set proxy auth if necessary.
+ if (!authorization_headers.empty())
+ tunnel_request_headers_->headers_ += authorization_headers;
tunnel_request_headers_->headers_ += "\r\n";
}
tunnel_request_headers_->SetDataOffset(tunnel_request_headers_bytes_sent_);
@@ -511,8 +570,21 @@ int SocketStream::DoReadTunnelHeadersComplete(int result) {
}
return OK;
case 407: // Proxy Authentication Required.
- // TODO(ukai): handle Proxy Authentication.
- break;
+ result = HandleAuthChallenge(headers.get());
+ if (result == ERR_PROXY_AUTH_REQUESTED &&
+ auth_handler_.get() && delegate_) {
+ auth_info_ = new AuthChallengeInfo;
+ auth_info_->is_proxy = true;
+ auth_info_->host_and_port =
+ ASCIIToWide(proxy_info_.proxy_server().host_and_port());
+ auth_info_->scheme = ASCIIToWide(auth_handler_->scheme());
+ auth_info_->realm = ASCIIToWide(auth_handler_->realm());
+ // Wait until RestartWithAuth or Close is called.
+ MessageLoop::current()->PostTask(
+ FROM_HERE,
+ NewRunnableMethod(this, &SocketStream::DoAuthRequired));
+ return ERR_IO_PENDING;
+ }
default:
break;
}
@@ -615,6 +687,79 @@ int SocketStream::DoReadWrite(int result) {
return ERR_IO_PENDING;
}
+GURL SocketStream::ProxyAuthOrigin() const {
+ return GURL("http://" + proxy_info_.proxy_server().host_and_port());
+}
+
+int SocketStream::HandleAuthChallenge(const HttpResponseHeaders* headers) {
+ GURL auth_origin(ProxyAuthOrigin());
+
+ LOG(INFO) << "The proxy " << auth_origin << " requested auth";
+
+ // The auth we tried just failed, hence it can't be valid.
+ // Remove it from the cache so it won't be used again.
+ if (auth_handler_.get() && !auth_identity_.invalid &&
+ auth_handler_->IsFinalRound()) {
+ if (auth_identity_.source != HttpAuth::IDENT_SRC_PATH_LOOKUP)
+ auth_cache_.Remove(auth_origin,
+ auth_handler_->realm(),
+ auth_identity_.username,
+ auth_identity_.password);
+ auth_handler_ = NULL;
+ auth_identity_ = HttpAuth::Identity();
+ }
+
+ auth_identity_.invalid = true;
+ HttpAuth::ChooseBestChallenge(headers, HttpAuth::AUTH_PROXY, auth_origin,
+ &auth_handler_);
+ if (!auth_handler_) {
+ LOG(ERROR) << "Can't perform auth to the proxy " << auth_origin;
+ return ERR_TUNNEL_CONNECTION_FAILED;
+ }
+ if (auth_handler_->NeedsIdentity()) {
+ HttpAuthCache::Entry* entry = auth_cache_.LookupByRealm(
+ auth_origin, auth_handler_->realm());
+ if (entry) {
+ if (entry->handler()->scheme() != "basic") {
+ // We only support basic authentication scheme now.
+ // TODO(ukai): Support other authentication scheme.
+ return ERR_TUNNEL_CONNECTION_FAILED;
+ }
+ auth_identity_.source = HttpAuth::IDENT_SRC_REALM_LOOKUP;
+ auth_identity_.invalid = false;
+ auth_identity_.username = entry->username();
+ auth_identity_.password = entry->password();
+ // Restart with auth info.
+ }
+ return ERR_PROXY_AUTH_REQUESTED;
+ } else {
+ auth_identity_.invalid = false;
+ }
+ return ERR_TUNNEL_CONNECTION_FAILED;
+}
+
+void SocketStream::DoAuthRequired() {
+ if (delegate_ && auth_info_.get())
+ delegate_->OnAuthRequired(this, auth_info_.get());
+ else
+ DoLoop(net::ERR_UNEXPECTED);
+}
+
+void SocketStream::DoRestartWithAuth() {
+ auth_cache_.Add(ProxyAuthOrigin(), auth_handler_,
+ auth_identity_.username, auth_identity_.password,
+ std::string());
+
+ tunnel_request_headers_ = NULL;
+ tunnel_request_headers_bytes_sent_ = 0;
+ tunnel_response_headers_ = NULL;
+ tunnel_response_headers_capacity_ = 0;
+ tunnel_response_headers_len_ = 0;
+
+ next_state_ = STATE_TCP_CONNECT;
+ DoLoop(OK);
+}
+
int SocketStream::HandleCertificateError(int result) {
// TODO(ukai): handle cert error properly.
switch (result) {
diff --git a/net/socket_stream/socket_stream.h b/net/socket_stream/socket_stream.h
index 2a3dcb2..c1aaf6e 100644
--- a/net/socket_stream/socket_stream.h
+++ b/net/socket_stream/socket_stream.h
@@ -6,6 +6,7 @@
#define NET_SOCKET_STREAM_SOCKET_STREAM_H_
#include <deque>
+#include <map>
#include <string>
#include <vector>
@@ -16,17 +17,27 @@
#include "net/base/address_list.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
+#include "net/http/http_auth.h"
+#include "net/http/http_auth_cache.h"
+#include "net/http/http_auth_handler.h"
#include "net/proxy/proxy_service.h"
#include "net/socket/tcp_client_socket.h"
#include "net/url_request/url_request_context.h"
namespace net {
+class AuthChallengeInfo;
class ClientSocketFactory;
class HostResolver;
class SSLConfigService;
class SingleRequestHostResolver;
+// SocketStream is used to implement Web Sockets.
+// It provides plain full-duplex stream with proxy and SSL support.
+// For proxy authentication, only basic mechanisum is supported. It will try
+// authentication identity for proxy URL first. If server requires proxy
+// authentication, it will try authentication identity for realm that server
+// requests.
class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
public:
// Derive from this class and add your own data members to associate extra
@@ -59,6 +70,15 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
// Called when the socket stream has been closed.
virtual void OnClose(SocketStream* socket) = 0;
+
+ // Called when proxy authentication required.
+ // The delegate should call RestartWithAuth() if credential for |auth_info|
+ // is found in password database, or call Close() to close the connection.
+ virtual void OnAuthRequired(SocketStream* socket,
+ AuthChallengeInfo* auth_info) {
+ // By default, no credential is available and close the connection.
+ socket->Close();
+ }
};
SocketStream(const GURL& url, Delegate* delegate);
@@ -91,6 +111,12 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
// Once the connection is closed, calls delegate's OnClose.
void Close();
+ // Restarts with authentication info.
+ // Should be used for response of OnAuthRequired.
+ void RestartWithAuth(
+ const std::wstring& username,
+ const std::wstring& password);
+
// Detach delegate. Call before delegate is deleted.
// Once delegate is detached, close the socket stream and never call delegate
// back.
@@ -145,7 +171,6 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
STATE_SOCKS_CONNECT_COMPLETE,
STATE_SSL_CONNECT,
STATE_SSL_CONNECT_COMPLETE,
- STATE_CONNECTION_ESTABLISHED,
STATE_READ_WRITE
};
@@ -189,6 +214,11 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
int DoSSLConnectComplete(int result);
int DoReadWrite(int result);
+ GURL ProxyAuthOrigin() const;
+ int HandleAuthChallenge(const HttpResponseHeaders* headers);
+ void DoAuthRequired();
+ void DoRestartWithAuth();
+
int HandleCertificateError(int result);
bool is_secure() const;
@@ -212,6 +242,11 @@ class SocketStream : public base::RefCountedThreadSafe<SocketStream> {
ProxyService::PacRequest* pac_request_;
ProxyInfo proxy_info_;
+ HttpAuthCache auth_cache_;
+ scoped_refptr<HttpAuthHandler> auth_handler_;
+ HttpAuth::Identity auth_identity_;
+ scoped_refptr<AuthChallengeInfo> auth_info_;
+
scoped_refptr<RequestHeaders> tunnel_request_headers_;
size_t tunnel_request_headers_bytes_sent_;
scoped_refptr<ResponseHeaders> tunnel_response_headers_;
diff --git a/net/socket_stream/socket_stream_unittest.cc b/net/socket_stream/socket_stream_unittest.cc
new file mode 100644
index 0000000..5cef70e
--- /dev/null
+++ b/net/socket_stream/socket_stream_unittest.cc
@@ -0,0 +1,210 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <string>
+#include <vector>
+
+#include "net/base/mock_host_resolver.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket_stream/socket_stream.h"
+#include "net/url_request/url_request_unittest.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "testing/platform_test.h"
+
+struct SocketStreamEvent {
+ enum EventType {
+ EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE,
+ EVENT_AUTH_REQUIRED,
+ };
+
+ SocketStreamEvent(EventType type, net::SocketStream* socket_stream,
+ int num, const std::string& str,
+ net::AuthChallengeInfo* auth_challenge_info)
+ : event_type(type), socket(socket_stream), number(num), data(str),
+ auth_info(auth_challenge_info) {}
+
+ EventType event_type;
+ net::SocketStream* socket;
+ int number;
+ std::string data;
+ scoped_refptr<net::AuthChallengeInfo> auth_info;
+};
+
+class SocketStreamEventRecorder : public net::SocketStream::Delegate {
+ public:
+ explicit SocketStreamEventRecorder(net::CompletionCallback* callback)
+ : on_connected_(NULL),
+ on_sent_data_(NULL),
+ on_received_data_(NULL),
+ on_close_(NULL),
+ on_auth_required_(NULL),
+ callback_(callback) {}
+ virtual ~SocketStreamEventRecorder() {
+ delete on_connected_;
+ delete on_sent_data_;
+ delete on_received_data_;
+ delete on_close_;
+ delete on_auth_required_;
+ }
+
+ void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) {
+ on_connected_ = callback;
+ }
+ void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) {
+ on_sent_data_ = callback;
+ }
+ void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) {
+ on_received_data_ = callback;
+ }
+ void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) {
+ on_close_ = callback;
+ }
+ void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) {
+ on_auth_required_ = callback;
+ }
+
+ virtual void OnConnected(net::SocketStream* socket,
+ int num_pending_send_allowed) {
+ events_.push_back(
+ SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED,
+ socket, num_pending_send_allowed, std::string(),
+ NULL));
+ if (on_connected_)
+ on_connected_->Run(&events_.back());
+ }
+ virtual void OnSentData(net::SocketStream* socket,
+ int amount_sent) {
+ events_.push_back(
+ SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA,
+ socket, amount_sent, std::string(), NULL));
+ if (on_sent_data_)
+ on_sent_data_->Run(&events_.back());
+ }
+ virtual void OnReceivedData(net::SocketStream* socket,
+ const char* data, int len) {
+ events_.push_back(
+ SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA,
+ socket, len, std::string(data, len), NULL));
+ if (on_received_data_)
+ on_received_data_->Run(&events_.back());
+ }
+ virtual void OnClose(net::SocketStream* socket) {
+ events_.push_back(
+ SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE,
+ socket, 0, std::string(), NULL));
+ if (on_close_)
+ on_close_->Run(&events_.back());
+ if (callback_)
+ callback_->Run(net::OK);
+ }
+ virtual void OnAuthRequired(net::SocketStream* socket,
+ net::AuthChallengeInfo* auth_info) {
+ events_.push_back(
+ SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED,
+ socket, 0, std::string(), auth_info));
+ if (on_auth_required_)
+ on_auth_required_->Run(&events_.back());
+ }
+
+ void DoClose(SocketStreamEvent* event) {
+ event->socket->Close();
+ }
+ void DoRestartWithAuth(SocketStreamEvent* event) {
+ LOG(INFO) << "RestartWithAuth username=" << username_
+ << " password=" << password_;
+ event->socket->RestartWithAuth(username_, password_);
+ }
+ void SetAuthInfo(const std::wstring& username,
+ const std::wstring& password) {
+ username_ = username;
+ password_ = password;
+ }
+
+ const std::vector<SocketStreamEvent>& GetSeenEvents() const {
+ return events_;
+ }
+
+ private:
+ std::vector<SocketStreamEvent> events_;
+ Callback1<SocketStreamEvent*>::Type* on_connected_;
+ Callback1<SocketStreamEvent*>::Type* on_sent_data_;
+ Callback1<SocketStreamEvent*>::Type* on_received_data_;
+ Callback1<SocketStreamEvent*>::Type* on_close_;
+ Callback1<SocketStreamEvent*>::Type* on_auth_required_;
+ net::CompletionCallback* callback_;
+
+ std::wstring username_;
+ std::wstring password_;
+
+ DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder);
+};
+
+namespace net {
+
+class SocketStreamTest : public PlatformTest {
+};
+
+TEST_F(SocketStreamTest, BasicAuthProxy) {
+ MockClientSocketFactory mock_socket_factory;
+ MockWrite data_writes1[] = {
+ MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
+ "Host: example.com\r\n"
+ "Proxy-Connection: keep-alive\r\n\r\n"),
+ };
+ MockRead data_reads1[] = {
+ MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
+ MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
+ MockRead("\r\n"),
+ };
+ StaticMockSocket data1(data_reads1, data_writes1);
+ mock_socket_factory.AddMockSocket(&data1);
+
+ MockWrite data_writes2[] = {
+ MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
+ "Host: example.com\r\n"
+ "Proxy-Connection: keep-alive\r\n"
+ "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
+ };
+ MockRead data_reads2[] = {
+ MockRead("HTTP/1.1 200 Connection Established\r\n"),
+ MockRead("Proxy-agent: Apache/2.2.8\r\n"),
+ MockRead("\r\n"),
+ };
+ StaticMockSocket data2(data_reads2, data_writes2);
+ mock_socket_factory.AddMockSocket(&data2);
+
+ TestCompletionCallback callback;
+
+ scoped_ptr<SocketStreamEventRecorder> delegate(
+ new SocketStreamEventRecorder(&callback));
+ delegate->SetOnConnected(NewCallback(delegate.get(),
+ &SocketStreamEventRecorder::DoClose));
+ const std::wstring kUsername = L"foo";
+ const std::wstring kPassword = L"bar";
+ delegate->SetAuthInfo(kUsername, kPassword);
+ delegate->SetOnAuthRequired(
+ NewCallback(delegate.get(),
+ &SocketStreamEventRecorder::DoRestartWithAuth));
+
+ scoped_refptr<SocketStream> socket_stream =
+ new SocketStream(GURL("ws://example.com/demo"), delegate.get());
+
+ socket_stream->set_context(new TestURLRequestContext("myproxy:70"));
+ socket_stream->SetHostResolver(new MockHostResolver());
+ socket_stream->SetClientSocketFactory(&mock_socket_factory);
+
+ socket_stream->Connect();
+
+ callback.WaitForResult();
+
+ const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
+ EXPECT_EQ(3U, events.size());
+
+ EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
+ EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);
+}
+
+} // namespace net