diff options
author | rtenneti@chromium.org <rtenneti@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-08-16 03:30:45 +0000 |
---|---|---|
committer | rtenneti@chromium.org <rtenneti@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-08-16 03:30:45 +0000 |
commit | cd01a1310062be0b0eb5589c62e4fb0951c90bac (patch) | |
tree | 3e1152f901ba794c4b0e9c1b3c8d3d158529ccd9 /chrome/browser/net | |
parent | 91bc704081b6771e423e0ebb69553634c048c14e (diff) | |
download | chromium_src-cd01a1310062be0b0eb5589c62e4fb0951c90bac.zip chromium_src-cd01a1310062be0b0eb5589c62e4fb0951c90bac.tar.gz chromium_src-cd01a1310062be0b0eb5589c62e4fb0951c90bac.tar.bz2 |
Prevent DOS attack on UDP echo servers by distinguishing between an echo request
and the echo response.
Client sends <version><checksum><size><payload> data to TCP/UDP echo servers.
<checksum> is the checksum of the <payload>. For the first cut, we will sum up
the characters in <payload>.
If checksum of the <payload> is verified, echo servers encrypt the data
and send back the data as <version><checksum><size><key><encrypted_payload>.
<key> is is used to decrypt the <encrypted_payload>. <encrypted_payload> is
the encrypted <payload>.
BUG=87297
R=jar
TEST=network_stats unit tests.
Review URL: http://codereview.chromium.org/7246021
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@96890 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'chrome/browser/net')
-rw-r--r-- | chrome/browser/net/network_stats.cc | 225 | ||||
-rw-r--r-- | chrome/browser/net/network_stats.h | 64 |
2 files changed, 251 insertions, 38 deletions
diff --git a/chrome/browser/net/network_stats.cc b/chrome/browser/net/network_stats.cc index f6f745a..fc16084 100644 --- a/chrome/browser/net/network_stats.cc +++ b/chrome/browser/net/network_stats.cc @@ -27,35 +27,79 @@ namespace chrome_browser_net { // This specifies the number of bytes to be sent to the TCP/UDP servers as part // of small packet size test. -static const int kSmallTestBytesToSend = 100; +static const uint32 kSmallTestBytesToSend = 100; // This specifies the number of bytes to be sent to the TCP/UDP servers as part // of large packet size test. -static const int kLargeTestBytesToSend = 1200; +static const uint32 kLargeTestBytesToSend = 1200; + +// This specifies the maximum message (payload) size. +static const uint32 kMaxMessage = 2048; + +// This specifies starting position of the <version> and length of the +// <version> in "echo request" and "echo response". +static const uint32 kVersionNumber = 1; +static const uint32 kVersionStart = 0; +static const uint32 kVersionLength = 2; +static const uint32 kVersionEnd = kVersionStart + kVersionLength; + +// This specifies the starting position of the <checksum> and length of the +// <checksum> in "echo request" and "echo response". Maximum value for the +// <checksum> is less than (2 ** 31 - 1). +static const uint32 kChecksumStart = kVersionEnd; +static const uint32 kChecksumLength = 10; +static const uint32 kChecksumEnd = kChecksumStart + kChecksumLength; + +// This specifies the starting position of the <payload_size> and length of the +// <payload_size> in "echo request" and "echo response". Maximum number of bytes +// that can be sent in the <payload> is 9,999,999. +static const uint32 kPayloadSizeStart = kChecksumEnd; +static const uint32 kPayloadSizeLength = 7; +static const uint32 kPayloadSizeEnd = kPayloadSizeStart + kPayloadSizeLength; + +// This specifies the starting position of the <key> and length of the <key> in +// "echo response". +static const uint32 kKeyStart = kPayloadSizeEnd; +static const uint32 kKeyLength = 6; +static const uint32 kKeyEnd = kKeyStart + kKeyLength; +static const int32 kKeyMinValue = 0; +static const int32 kKeyMaxValue = 999999; + +// This specifies the starting position of the <payload> in "echo request". +static const uint32 kPayloadStart = kPayloadSizeEnd; + +// This specifies the starting position of the <encoded_payload> and length of +// the <encoded_payload> in "echo response". +static const uint32 kEncodedPayloadStart = kKeyEnd; // NetworkStats methods and members. NetworkStats::NetworkStats() - : bytes_to_read_(0), + : load_size_(0), + bytes_to_read_(0), bytes_to_send_(0), + encoded_message_(""), ALLOW_THIS_IN_INITIALIZER_LIST( read_callback_(this, &NetworkStats::OnReadComplete)), ALLOW_THIS_IN_INITIALIZER_LIST( write_callback_(this, &NetworkStats::OnWriteComplete)), finished_callback_(NULL), - start_time_(base::TimeTicks::Now()) { + start_time_(base::TimeTicks::Now()), + ALLOW_THIS_IN_INITIALIZER_LIST(timers_factory_(this)) { } NetworkStats::~NetworkStats() { socket_.reset(); } -void NetworkStats::Initialize(int bytes_to_send, +void NetworkStats::Initialize(uint32 bytes_to_send, net::CompletionCallback* finished_callback) { DCHECK(bytes_to_send); // We should have data to send. load_size_ = bytes_to_send; - bytes_to_send_ = bytes_to_send; - bytes_to_read_ = bytes_to_send; + bytes_to_send_ = kVersionLength + kChecksumLength + kPayloadSizeLength + + load_size_; + bytes_to_read_ = kVersionLength + kChecksumLength + kPayloadSizeLength + + kKeyLength + load_size_; finished_callback_ = finished_callback; } @@ -78,6 +122,11 @@ bool NetworkStats::DoStart(int result) { } stream_.Reset(); + + // Timeout if we don't get response back from echo servers in 60 secs. + const int kReadDataTimeoutMs = 60000; + StartReadDataTimer(kReadDataTimeoutMs); + ReadData(); return true; @@ -105,25 +154,34 @@ bool NetworkStats::ReadComplete(int result) { return true; } - if (!stream_.VerifyBytes(read_buffer_->data(), result)) { - Finish(READ_VERIFY_FAILED, net::ERR_INVALID_RESPONSE); - return true; - } + encoded_message_.append(read_buffer_->data(), result); read_buffer_ = NULL; bytes_to_read_ -= result; // No more data to read. - if (!bytes_to_read_) { - Finish(SUCCESS, net::OK); + if (!bytes_to_read_ || result == 0) { + if (VerifyBytes()) + Finish(SUCCESS, net::OK); + else + Finish(READ_VERIFY_FAILED, net::ERR_INVALID_RESPONSE); return true; } - ReadData(); return false; } void NetworkStats::OnReadComplete(int result) { - ReadComplete(result); + if (!ReadComplete(result)) { + // Called ReadData() via PostDelayedTask() to avoid recursion. Added a delay + // of 1ms so that the time-out will fire before we have time to really hog + // the CPU too extensively (waiting for the time-out) in case of an infinite + // loop. + const int kReadDataDelayMs = 1; + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + timers_factory_.NewRunnableMethod(&NetworkStats::ReadData), + kReadDataDelayMs); + } } void NetworkStats::OnWriteComplete(int result) { @@ -151,19 +209,21 @@ void NetworkStats::OnWriteComplete(int result) { } void NetworkStats::ReadData() { - DCHECK(!read_buffer_.get()); - int kMaxMessage = 2048; - - // We release the read_buffer_ in the destructor if there is an error. - read_buffer_ = new net::IOBuffer(kMaxMessage); - int rv; do { - DCHECK(socket_.get()); + if (!socket_.get()) + return; + + DCHECK(!read_buffer_.get()); + + // We release the read_buffer_ in the destructor if there is an error. + read_buffer_ = new net::IOBuffer(kMaxMessage); + rv = socket_->Read(read_buffer_, kMaxMessage, &read_callback_); if (rv == net::ERR_IO_PENDING) return; - if (ReadComplete(rv)) // Complete the read manually. + // If we have read all the data then return. + if (ReadComplete(rv)) return; } while (rv > 0); } @@ -173,11 +233,12 @@ int NetworkStats::SendData() { do { if (!write_buffer_.get()) { scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(bytes_to_send_)); - stream_.GetBytes(buffer->data(), bytes_to_send_); + GetEchoRequest(buffer); write_buffer_ = new net::DrainableIOBuffer(buffer, bytes_to_send_); } - DCHECK(socket_.get()); + if (!socket_.get()) + return net::ERR_UNEXPECTED; int rv = socket_->Write(write_buffer_, write_buffer_->BytesRemaining(), &write_callback_); @@ -191,6 +252,88 @@ int NetworkStats::SendData() { return net::OK; } +void NetworkStats::StartReadDataTimer(int milliseconds) { + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + timers_factory_.NewRunnableMethod(&NetworkStats::OnReadDataTimeout), + milliseconds); +} + +void NetworkStats::OnReadDataTimeout() { + Finish(READ_TIMED_OUT, net::ERR_INVALID_ARGUMENT); +} + +void NetworkStats::GetEchoRequest(net::IOBuffer* io_buffer) { + // Copy the <version> into the io_buffer starting from the kVersionStart + // position. + std::string version = base::StringPrintf("%02d", kVersionNumber); + char* buffer = io_buffer->data() + kVersionStart; + DCHECK(kVersionLength == version.length()); + memcpy(buffer, version.c_str(), kVersionLength); + + // Get the <payload> from the |stream_| and copy it into io_buffer starting + // from the kPayloadStart position. + buffer = io_buffer->data() + kPayloadStart; + stream_.GetBytes(buffer, load_size_); + + // Calculate the <checksum> of the <payload>. + uint32 sum = 0; + for (uint32 i = 0; i < load_size_; ++i) + sum += buffer[i]; + + // Copy the <checksum> into the io_buffer starting from the kChecksumStart + // position. + std::string checksum = base::StringPrintf("%010d", sum); + buffer = io_buffer->data() + kChecksumStart; + DCHECK(kChecksumLength == checksum.length()); + memcpy(buffer, checksum.c_str(), kChecksumLength); + + // Copy the size of the <payload> into the io_buffer starting from the + // kPayloadSizeStart position. + buffer = io_buffer->data() + kPayloadSizeStart; + std::string payload_size = base::StringPrintf("%07d", load_size_); + DCHECK(kPayloadSizeLength == payload_size.length()); + memcpy(buffer, payload_size.c_str(), kPayloadSizeLength); +} + +bool NetworkStats::VerifyBytes() { + // If the "echo response" doesn't have enough bytes, then return false. + if (encoded_message_.length() < kEncodedPayloadStart) + return false; + + // Extract the |key| from the "echo response". + std::string key_string = encoded_message_.substr(kKeyStart, kKeyLength); + const char* key = key_string.c_str(); + int key_value = atoi(key); + if (key_value < kKeyMinValue || key_value > kKeyMaxValue) + return false; + + std::string encoded_payload = + encoded_message_.substr(kEncodedPayloadStart); + const char* encoded_data = encoded_payload.c_str(); + uint32 message_length = encoded_payload.length(); + message_length = std::min(message_length, kMaxMessage); + // We should get back all the data we had sent. + if (message_length != load_size_) + return false; + + // Decrypt the data by looping through the |encoded_data| and XOR each byte + // with the |key| to get the decoded byte. Append the decoded byte to the + // |decoded_data|. + char decoded_data[kMaxMessage + 1]; + for (uint32 data_index = 0, key_index = 0; + data_index < message_length; + ++data_index) { + char encoded_byte = encoded_data[data_index]; + char key_byte = key[key_index]; + char decoded_byte = encoded_byte ^ key_byte; + decoded_data[data_index] = decoded_byte; + key_index = (key_index + 1) % kKeyLength; + } + + return stream_.VerifyBytes(decoded_data, message_length); +} + // UDPStatsClient methods and members. UDPStatsClient::UDPStatsClient() : NetworkStats() { @@ -201,7 +344,7 @@ UDPStatsClient::~UDPStatsClient() { bool UDPStatsClient::Start(const std::string& ip_str, int port, - int bytes_to_send, + uint32 bytes_to_send, net::CompletionCallback* finished_callback) { DCHECK(port); DCHECK(bytes_to_send); // We should have data to send. @@ -220,13 +363,25 @@ bool UDPStatsClient::Start(const std::string& ip_str, net::RandIntCallback(), NULL, net::NetLog::Source()); - DCHECK(udp_socket); + if (!udp_socket) { + Finish(SOCKET_CREATE_FAILED, net::ERR_INVALID_ARGUMENT); + return false; + } set_socket(udp_socket); int rv = udp_socket->Connect(server_address); return DoStart(rv); } +bool UDPStatsClient::ReadComplete(int result) { + DCHECK_NE(net::ERR_IO_PENDING, result); + if (result <= 0) { + Finish(READ_FAILED, result); + return true; + } + return NetworkStats::ReadComplete(result); +} + void UDPStatsClient::Finish(Status status, int result) { base::TimeDelta duration = base::TimeTicks::Now() - start_time(); if (load_size() == kSmallTestBytesToSend) { @@ -272,7 +427,7 @@ TCPStatsClient::~TCPStatsClient() { bool TCPStatsClient::Start(net::HostResolver* host_resolver, const net::HostPortPair& server_host_port_pair, - int bytes_to_send, + uint32 bytes_to_send, net::CompletionCallback* finished_callback) { DCHECK(bytes_to_send); // We should have data to send. @@ -301,7 +456,10 @@ bool TCPStatsClient::DoConnect(int result) { net::TCPClientSocket* tcp_socket = new net::TCPClientSocket(addresses_, NULL, net::NetLog::Source()); - DCHECK(tcp_socket); + if (!tcp_socket) { + Finish(SOCKET_CREATE_FAILED, net::ERR_INVALID_ARGUMENT); + return false; + } set_socket(tcp_socket); int rv = tcp_socket->Connect(&connect_callback_); @@ -315,6 +473,15 @@ void TCPStatsClient::OnConnectComplete(int result) { DoStart(result); } +bool TCPStatsClient::ReadComplete(int result) { + DCHECK_NE(net::ERR_IO_PENDING, result); + if (result < 0) { + Finish(READ_FAILED, result); + return true; + } + return NetworkStats::ReadComplete(result); +} + void TCPStatsClient::Finish(Status status, int result) { base::TimeDelta duration = base::TimeTicks::Now() - start_time(); if (load_size() == kSmallTestBytesToSend) { diff --git a/chrome/browser/net/network_stats.h b/chrome/browser/net/network_stats.h index a7c2936..bbc449f 100644 --- a/chrome/browser/net/network_stats.h +++ b/chrome/browser/net/network_stats.h @@ -11,6 +11,7 @@ #include "base/basictypes.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" +#include "base/string_util.h" #include "base/time.h" #include "chrome/browser/io_thread.h" #include "net/base/address_list.h" @@ -34,15 +35,30 @@ namespace chrome_browser_net { // c) What is the latency for UDP and TCP. // d) If connectivity failed, at what stage (Connect or Write or Read) did it // fail? +// +// The following is the overview of the echo message protocol. +// +// We send the "echo request" to the TCP/UDP servers in the following format: +// <version><checksum><payload_size><payload>. <version> is the version number +// of the "echo request". <checksum> is the checksum of the <payload>. +// <payload_size> specifies the number of bytes in the <payload>. +// +// TCP/UDP servers respond to the "echo request" by returning "echo response". +// "echo response" is of the format: +// "<version><checksum><payload_size><key><encoded_payload>". <payload_size> +// specifies the number of bytes in the <encoded_payload>. <key> is used to +// decode the <encoded_payload>. class NetworkStats { public: enum Status { // Used in HISTOGRAM_ENUMERATION. SUCCESS, // Successfully received bytes from the server. IP_STRING_PARSE_FAILED, // Parsing of IP string failed. + SOCKET_CREATE_FAILED, // Socket creation failed. RESOLVE_FAILED, // Host resolution failed. CONNECT_FAILED, // Connection to the server failed. WRITE_FAILED, // Sending an echo message to the server failed. + READ_TIMED_OUT, // Reading the reply from the server timed out. READ_FAILED, // Reading the reply from the server failed. READ_VERIFY_FAILED, // Verification of data failed. STATUS_MAX, // Bounding value. @@ -57,7 +73,7 @@ class NetworkStats { // Initializes |finished_callback_| and the number of bytes to send to the // server. |finished_callback| is called when we are done with the test. // |finished_callback| is mainly useful for unittests. - void Initialize(int bytes_to_send, + void Initialize(uint32 bytes_to_send, net::CompletionCallback* finished_callback); // This method is called after socket connection is completed. It will send @@ -74,8 +90,12 @@ class NetworkStats { // to indicate that the test has finished. void DoFinishCallback(int result); + // Verifies the data and calls Finish() if there is an error or if all bytes + // are read. Returns true if Finish() is called otherwise returns false. + virtual bool ReadComplete(int result); + // Returns the number of bytes to be sent to the |server|. - int load_size() const { return load_size_; } + uint32 load_size() const { return load_size_; } // Helper methods to get and set |socket_|. net::Socket* socket() { return socket_.get(); } @@ -85,10 +105,6 @@ class NetworkStats { base::TimeTicks start_time() const { return start_time_; } private: - // Verifies the data and calls Finish() if there is an error or if all bytes - // are read. Returns true if Finish() is called otherwise returns false. - bool ReadComplete(int result); - // Callbacks when an internal IO is completed. void OnReadComplete(int result); void OnWriteComplete(int result); @@ -99,6 +115,21 @@ class NetworkStats { // Sends data to server until an error occurs. int SendData(); + // We set a timeout for responses from the echo servers. + void StartReadDataTimer(int milliseconds); + void OnReadDataTimeout(); // Called when the ReadData Timer fires. + + // Fills the |io_buffer| with the "echo request" message. This gets the + // <payload> from |stream_| and calculates the <checksum> of the <payload> and + // returns the "echo request" that has <version>, <checksum>, <payload_size> + // and <payload>. + void GetEchoRequest(net::IOBuffer* io_buffer); + + // This method parses the "echo response" message in the |encoded_message_| to + // verify that the <payload> is same as what we had sent in "echo request" + // message. + bool VerifyBytes(); + // The socket handle for this session. scoped_ptr<net::Socket> socket_; @@ -109,10 +140,13 @@ class NetworkStats { scoped_refptr<net::DrainableIOBuffer> write_buffer_; // Some counters for the session. - int load_size_; + uint32 load_size_; int bytes_to_read_; int bytes_to_send_; + // The encoded message read from the server. + std::string encoded_message_; + // |stream_| is used to generate data to be sent to the server and it is also // used to verify the data received from the server. net::TestDataStream stream_; @@ -130,6 +164,9 @@ class NetworkStats { // The time when the session was started. base::TimeTicks start_time_; + + // We use this factory to create timeout tasks for socket's ReadData. + ScopedRunnableMethodFactory<NetworkStats> timers_factory_; }; class UDPStatsClient : public NetworkStats { @@ -148,7 +185,7 @@ class UDPStatsClient : public NetworkStats { // Returns true if successful in starting the client. bool Start(const std::string& ip_str, int port, - int bytes_to_send, + uint32 bytes_to_send, net::CompletionCallback* callback); protected: @@ -158,6 +195,11 @@ class UDPStatsClient : public NetworkStats { // Collects stats for UDP connectivity. This is called when all the data from // server is read or when there is a failure during connect/read/write. virtual void Finish(Status status, int result); + + // This method calls NetworkStats::ReadComplete() to verify the data and calls + // Finish() if there is an error or if read callback didn't return any data + // (|result| is less than or equal to 0). + virtual bool ReadComplete(int result); }; class TCPStatsClient : public NetworkStats { @@ -176,7 +218,7 @@ class TCPStatsClient : public NetworkStats { // Returns true if successful in starting the client. bool Start(net::HostResolver* host_resolver, const net::HostPortPair& server, - int bytes_to_send, + uint32 bytes_to_send, net::CompletionCallback* callback); protected: @@ -187,6 +229,10 @@ class TCPStatsClient : public NetworkStats { // server is read or when there is a failure during connect/read/write. virtual void Finish(Status status, int result); + // This method calls NetworkStats::ReadComplete() to verify the data and calls + // Finish() if there is an error (|result| is less than 0). + virtual bool ReadComplete(int result); + private: // Callback that is called when host resolution is completed. void OnResolveComplete(int result); |