From 92aea484adc997768974c8687cba5c6202256278 Mon Sep 17 00:00:00 2001
From: "ukai@chromium.org"
 <ukai@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98>
Date: Fri, 23 Oct 2009 02:25:38 +0000
Subject: Add unittest for WebSocket::ProcessFrameData.

It catches a bug fixed in r29760.
Also find another bug and fixed in this change.
(same bug found in WebKit code. reported as
 http://bugs.webkit.org/show_bug.cgi?id=30668)

BUG=none
TEST=net_unittests passes

Review URL: http://codereview.chromium.org/307036

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@29864 0039d316-1c4b-4281-b951-d872f2087c98
---
 net/websockets/websocket.cc          |  32 ++++--
 net/websockets/websocket.h           |   5 +
 net/websockets/websocket_unittest.cc | 217 ++++++++++++++++++++++++++---------
 3 files changed, 189 insertions(+), 65 deletions(-)

(limited to 'net/websockets')

diff --git a/net/websockets/websocket.cc b/net/websockets/websocket.cc
index 9b2b142..a00f6bd 100644
--- a/net/websockets/websocket.cc
+++ b/net/websockets/websocket.cc
@@ -48,7 +48,7 @@ WebSocket::WebSocket(Request* request, WebSocketDelegate* delegate)
 }
 
 WebSocket::~WebSocket() {
-  DCHECK(!delegate_);
+  DCHECK(ready_state_ == INITIALIZED || !delegate_);
   DCHECK(!socket_stream_);
 }
 
@@ -131,17 +131,7 @@ void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) {
 void WebSocket::OnReceivedData(SocketStream* socket_stream,
                                const char* data, int len) {
   DCHECK(socket_stream == socket_stream_);
-  DCHECK(current_read_buf_);
-  // Check if |current_read_buf_| has enough space to store |len| of |data|.
-  if (len >= current_read_buf_->RemainingCapacity()) {
-    current_read_buf_->set_capacity(
-        current_read_buf_->offset() + len);
-  }
-
-  DCHECK(current_read_buf_->RemainingCapacity() >= len);
-  memcpy(current_read_buf_->data(), data, len);
-  current_read_buf_->set_offset(current_read_buf_->offset() + len);
-
+  AddToReadBuffer(data, len);
   origin_loop_->PostTask(FROM_HERE,
                          NewRunnableMethod(this, &WebSocket::DoReceivedData));
 }
@@ -389,6 +379,8 @@ void WebSocket::ProcessFrameData() {
       if (p + length < end) {
         p += length;
         next_frame = p;
+      } else {
+        break;
       }
     } else {
       const char* msg_start = p;
@@ -405,7 +397,23 @@ void WebSocket::ProcessFrameData() {
   SkipReadBuffer(next_frame - start_frame);
 }
 
+void WebSocket::AddToReadBuffer(const char* data, int len) {
+  DCHECK(current_read_buf_);
+  // Check if |current_read_buf_| has enough space to store |len| of |data|.
+  if (len >= current_read_buf_->RemainingCapacity()) {
+    current_read_buf_->set_capacity(
+        current_read_buf_->offset() + len);
+  }
+
+  DCHECK(current_read_buf_->RemainingCapacity() >= len);
+  memcpy(current_read_buf_->data(), data, len);
+  current_read_buf_->set_offset(current_read_buf_->offset() + len);
+}
+
 void WebSocket::SkipReadBuffer(int len) {
+  if (len == 0)
+    return;
+  DCHECK_GT(len, 0);
   read_consumed_len_ += len;
   int remaining = current_read_buf_->offset() - read_consumed_len_;
   DCHECK_GE(remaining, 0);
diff --git a/net/websockets/websocket.h b/net/websockets/websocket.h
index 1294382..2279681 100644
--- a/net/websockets/websocket.h
+++ b/net/websockets/websocket.h
@@ -139,6 +139,8 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
   };
   typedef std::deque< scoped_refptr<IOBufferWithSize> > PendingDataQueue;
 
+  friend class WebSocketTest;
+
   friend class base::RefCountedThreadSafe<WebSocket>;
   virtual ~WebSocket();
 
@@ -172,6 +174,9 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
   // Processes frame data in |current_read_buf_|.
   void ProcessFrameData();
 
+  // Adds |len| bytes of |data| to |current_read_buf_|.
+  void AddToReadBuffer(const char* data, int len);
+
   // Skips |len| bytes in |current_read_buf_|.
   void SkipReadBuffer(int len);
 
diff --git a/net/websockets/websocket_unittest.cc b/net/websockets/websocket_unittest.cc
index cdb9aa9..907f62a 100644
--- a/net/websockets/websocket_unittest.cc
+++ b/net/websockets/websocket_unittest.cc
@@ -7,6 +7,7 @@
 
 #include "base/task.h"
 #include "net/base/completion_callback.h"
+#include "net/base/io_buffer.h"
 #include "net/base/mock_host_resolver.h"
 #include "net/base/test_completion_callback.h"
 #include "net/socket/socket_test_util.h"
@@ -92,41 +93,70 @@ class WebSocketEventRecorder : public net::WebSocketDelegate {
   DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder);
 };
 
+namespace net {
+
 class WebSocketTest : public PlatformTest {
+ protected:
+  void InitReadBuf(WebSocket* websocket) {
+    // Set up |current_read_buf_|.
+    websocket->current_read_buf_ = new GrowableIOBuffer();
+  }
+  void SetReadConsumed(WebSocket* websocket, int consumed) {
+    websocket->read_consumed_len_ = consumed;
+  }
+  void AddToReadBuf(WebSocket* websocket, const char* data, int len) {
+    websocket->AddToReadBuffer(data, len);
+  }
+
+  void TestProcessFrameData(WebSocket* websocket,
+                            const char* expected_remaining_data,
+                            int expected_remaining_len) {
+    websocket->ProcessFrameData();
+
+    const char* actual_remaining_data =
+        websocket->current_read_buf_->StartOfBuffer()
+        + websocket->read_consumed_len_;
+    int actual_remaining_len =
+        websocket->current_read_buf_->offset() - websocket->read_consumed_len_;
+
+    EXPECT_EQ(expected_remaining_len, actual_remaining_len);
+    EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data,
+                        expected_remaining_len));
+  }
 };
 
 TEST_F(WebSocketTest, Connect) {
-  net::MockClientSocketFactory mock_socket_factory;
-  net::MockRead data_reads[] = {
-    net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
-                  "Upgrade: WebSocket\r\n"
-                  "Connection: Upgrade\r\n"
-                  "WebSocket-Origin: http://example.com\r\n"
-                  "WebSocket-Location: ws://example.com/demo\r\n"
-                  "WebSocket-Protocol: sample\r\n"
-                  "\r\n"),
+  MockClientSocketFactory mock_socket_factory;
+  MockRead data_reads[] = {
+    MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
+             "Upgrade: WebSocket\r\n"
+             "Connection: Upgrade\r\n"
+             "WebSocket-Origin: http://example.com\r\n"
+             "WebSocket-Location: ws://example.com/demo\r\n"
+             "WebSocket-Protocol: sample\r\n"
+             "\r\n"),
     // Server doesn't close the connection after handshake.
-    net::MockRead(true, net::ERR_IO_PENDING),
+    MockRead(true, ERR_IO_PENDING),
   };
-  net::MockWrite data_writes[] = {
-    net::MockWrite("GET /demo HTTP/1.1\r\n"
-                   "Upgrade: WebSocket\r\n"
-                   "Connection: Upgrade\r\n"
-                   "Host: example.com\r\n"
-                   "Origin: http://example.com\r\n"
-                   "WebSocket-Protocol: sample\r\n"
-                   "\r\n"),
+  MockWrite data_writes[] = {
+    MockWrite("GET /demo HTTP/1.1\r\n"
+              "Upgrade: WebSocket\r\n"
+              "Connection: Upgrade\r\n"
+              "Host: example.com\r\n"
+              "Origin: http://example.com\r\n"
+              "WebSocket-Protocol: sample\r\n"
+              "\r\n"),
   };
-  net::StaticMockSocket data(data_reads, data_writes);
+  StaticMockSocket data(data_reads, data_writes);
   mock_socket_factory.AddMockSocket(&data);
 
-  net::WebSocket::Request* request(
-      new net::WebSocket::Request(GURL("ws://example.com/demo"),
-                                  "sample",
-                                  "http://example.com",
-                                  "ws://example.com/demo",
-                                  new TestURLRequestContext()));
-  request->SetHostResolver(new net::MockHostResolver());
+  WebSocket::Request* request(
+      new WebSocket::Request(GURL("ws://example.com/demo"),
+                             "sample",
+                             "http://example.com",
+                             "ws://example.com/demo",
+                             new TestURLRequestContext()));
+  request->SetHostResolver(new MockHostResolver());
   request->SetClientSocketFactory(&mock_socket_factory);
 
   TestCompletionCallback callback;
@@ -136,10 +166,10 @@ TEST_F(WebSocketTest, Connect) {
   delegate->SetOnOpen(NewCallback(delegate.get(),
                                   &WebSocketEventRecorder::DoClose));
 
-  scoped_refptr<net::WebSocket> websocket(
-      new net::WebSocket(request, delegate.get()));
+  scoped_refptr<WebSocket> websocket(
+      new WebSocket(request, delegate.get()));
 
-  EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state());
+  EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
   websocket->Connect();
 
   callback.WaitForResult();
@@ -152,41 +182,41 @@ TEST_F(WebSocketTest, Connect) {
 }
 
 TEST_F(WebSocketTest, ServerSentData) {
-  net::MockClientSocketFactory mock_socket_factory;
+  MockClientSocketFactory mock_socket_factory;
   static const char kMessage[] = "Hello";
   static const char kFrame[] = "\x00Hello\xff";
   static const int kFrameLen = sizeof(kFrame) - 1;
-  net::MockRead data_reads[] = {
-    net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
+  MockRead data_reads[] = {
+    MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
                   "Upgrade: WebSocket\r\n"
                   "Connection: Upgrade\r\n"
                   "WebSocket-Origin: http://example.com\r\n"
                   "WebSocket-Location: ws://example.com/demo\r\n"
                   "WebSocket-Protocol: sample\r\n"
                   "\r\n"),
-    net::MockRead(true, kFrame, kFrameLen),
+    MockRead(true, kFrame, kFrameLen),
     // Server doesn't close the connection after handshake.
-    net::MockRead(true, net::ERR_IO_PENDING),
+    MockRead(true, ERR_IO_PENDING),
   };
-  net::MockWrite data_writes[] = {
-    net::MockWrite("GET /demo HTTP/1.1\r\n"
-                   "Upgrade: WebSocket\r\n"
-                   "Connection: Upgrade\r\n"
-                   "Host: example.com\r\n"
-                   "Origin: http://example.com\r\n"
-                   "WebSocket-Protocol: sample\r\n"
-                   "\r\n"),
+  MockWrite data_writes[] = {
+    MockWrite("GET /demo HTTP/1.1\r\n"
+              "Upgrade: WebSocket\r\n"
+              "Connection: Upgrade\r\n"
+              "Host: example.com\r\n"
+              "Origin: http://example.com\r\n"
+              "WebSocket-Protocol: sample\r\n"
+              "\r\n"),
   };
-  net::StaticMockSocket data(data_reads, data_writes);
+  StaticMockSocket data(data_reads, data_writes);
   mock_socket_factory.AddMockSocket(&data);
 
-  net::WebSocket::Request* request(
-      new net::WebSocket::Request(GURL("ws://example.com/demo"),
-                                  "sample",
-                                  "http://example.com",
-                                  "ws://example.com/demo",
-                                  new TestURLRequestContext()));
-  request->SetHostResolver(new net::MockHostResolver());
+  WebSocket::Request* request(
+      new WebSocket::Request(GURL("ws://example.com/demo"),
+                             "sample",
+                             "http://example.com",
+                             "ws://example.com/demo",
+                             new TestURLRequestContext()));
+  request->SetHostResolver(new MockHostResolver());
   request->SetClientSocketFactory(&mock_socket_factory);
 
   TestCompletionCallback callback;
@@ -196,10 +226,10 @@ TEST_F(WebSocketTest, ServerSentData) {
   delegate->SetOnMessage(NewCallback(delegate.get(),
                                      &WebSocketEventRecorder::DoClose));
 
-  scoped_refptr<net::WebSocket> websocket(
-      new net::WebSocket(request, delegate.get()));
+  scoped_refptr<WebSocket> websocket(
+      new WebSocket(request, delegate.get()));
 
-  EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state());
+  EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
   websocket->Connect();
 
   callback.WaitForResult();
@@ -212,3 +242,84 @@ TEST_F(WebSocketTest, ServerSentData) {
   EXPECT_EQ(kMessage, events[1].msg);
   EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type);
 }
+
+TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) {
+  WebSocket::Request* request(
+      new WebSocket::Request(GURL("ws://example.com/demo"),
+                             "sample",
+                             "http://example.com",
+                             "ws://example.com/demo",
+                             new TestURLRequestContext()));
+  TestCompletionCallback callback;
+  scoped_ptr<WebSocketEventRecorder> delegate(
+      new WebSocketEventRecorder(&callback));
+
+  scoped_refptr<WebSocket> websocket(
+      new WebSocket(request, delegate.get()));
+
+  // Frame data: skip length 1 ('x'), and try to skip length 129
+  // (1 * 128 + 1) bytes after second \x81, but buffer is too short to skip.
+  static const char kTestLengthFrame[] =
+      "\x80\x81x\x80\x81\x81\x01\x00unexpected data\xFF";
+  const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1;
+  InitReadBuf(websocket.get());
+  AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength);
+  SetReadConsumed(websocket.get(), 0);
+
+  static const char kExpectedRemainingFrame[] =
+      "\x80\x81\x81\x01\x00unexpected data\xFF";
+  const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1;
+  TestProcessFrameData(websocket.get(),
+                       kExpectedRemainingFrame, kExpectedRemainingLength);
+  // No onmessage event expected.
+  const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+  EXPECT_EQ(0U, events.size());
+}
+
+TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) {
+  WebSocket::Request* request(
+      new WebSocket::Request(GURL("ws://example.com/demo"),
+                             "sample",
+                             "http://example.com",
+                             "ws://example.com/demo",
+                             new TestURLRequestContext()));
+  TestCompletionCallback callback;
+  scoped_ptr<WebSocketEventRecorder> delegate(
+      new WebSocketEventRecorder(&callback));
+
+  scoped_refptr<WebSocket> websocket(
+      new WebSocket(request, delegate.get()));
+
+  static const char kTestUnterminatedFrame[] =
+      "\x00unterminated frame";
+  const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1;
+  InitReadBuf(websocket.get());
+  AddToReadBuf(websocket.get(), kTestUnterminatedFrame,
+               kTestUnterminatedFrameLength);
+  SetReadConsumed(websocket.get(), 0);
+  TestProcessFrameData(websocket.get(),
+                       kTestUnterminatedFrame, kTestUnterminatedFrameLength);
+  {
+    // No onmessage event expected.
+    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+    EXPECT_EQ(0U, events.size());
+  }
+
+  static const char kTestTerminateFrame[] = " is terminated in next read\xff";
+  const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1;
+  AddToReadBuf(websocket.get(), kTestTerminateFrame,
+               kTestTerminateFrameLength);
+  TestProcessFrameData(websocket.get(), "", 0);
+
+  static const char kExpectedMsg[] =
+      "unterminated frame is terminated in next read";
+  {
+    const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+    EXPECT_EQ(1U, events.size());
+
+    EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type);
+    EXPECT_EQ(kExpectedMsg, events[0].msg);
+  }
+}
+
+}  // namespace net
-- 
cgit v1.1