summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/websockets/websocket.cc32
-rw-r--r--net/websockets/websocket.h5
-rw-r--r--net/websockets/websocket_unittest.cc217
3 files changed, 189 insertions, 65 deletions
diff --git a/net/websockets/websocket.cc b/net/websockets/websocket.cc
index 9b2b142..a00f6bd 100644
--- a/net/websockets/websocket.cc
+++ b/net/websockets/websocket.cc
@@ -48,7 +48,7 @@ WebSocket::WebSocket(Request* request, WebSocketDelegate* delegate)
}
WebSocket::~WebSocket() {
- DCHECK(!delegate_);
+ DCHECK(ready_state_ == INITIALIZED || !delegate_);
DCHECK(!socket_stream_);
}
@@ -131,17 +131,7 @@ void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) {
void WebSocket::OnReceivedData(SocketStream* socket_stream,
const char* data, int len) {
DCHECK(socket_stream == socket_stream_);
- DCHECK(current_read_buf_);
- // Check if |current_read_buf_| has enough space to store |len| of |data|.
- if (len >= current_read_buf_->RemainingCapacity()) {
- current_read_buf_->set_capacity(
- current_read_buf_->offset() + len);
- }
-
- DCHECK(current_read_buf_->RemainingCapacity() >= len);
- memcpy(current_read_buf_->data(), data, len);
- current_read_buf_->set_offset(current_read_buf_->offset() + len);
-
+ AddToReadBuffer(data, len);
origin_loop_->PostTask(FROM_HERE,
NewRunnableMethod(this, &WebSocket::DoReceivedData));
}
@@ -389,6 +379,8 @@ void WebSocket::ProcessFrameData() {
if (p + length < end) {
p += length;
next_frame = p;
+ } else {
+ break;
}
} else {
const char* msg_start = p;
@@ -405,7 +397,23 @@ void WebSocket::ProcessFrameData() {
SkipReadBuffer(next_frame - start_frame);
}
+void WebSocket::AddToReadBuffer(const char* data, int len) {
+ DCHECK(current_read_buf_);
+ // Check if |current_read_buf_| has enough space to store |len| of |data|.
+ if (len >= current_read_buf_->RemainingCapacity()) {
+ current_read_buf_->set_capacity(
+ current_read_buf_->offset() + len);
+ }
+
+ DCHECK(current_read_buf_->RemainingCapacity() >= len);
+ memcpy(current_read_buf_->data(), data, len);
+ current_read_buf_->set_offset(current_read_buf_->offset() + len);
+}
+
void WebSocket::SkipReadBuffer(int len) {
+ if (len == 0)
+ return;
+ DCHECK_GT(len, 0);
read_consumed_len_ += len;
int remaining = current_read_buf_->offset() - read_consumed_len_;
DCHECK_GE(remaining, 0);
diff --git a/net/websockets/websocket.h b/net/websockets/websocket.h
index 1294382..2279681 100644
--- a/net/websockets/websocket.h
+++ b/net/websockets/websocket.h
@@ -139,6 +139,8 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
};
typedef std::deque< scoped_refptr<IOBufferWithSize> > PendingDataQueue;
+ friend class WebSocketTest;
+
friend class base::RefCountedThreadSafe<WebSocket>;
virtual ~WebSocket();
@@ -172,6 +174,9 @@ class WebSocket : public base::RefCountedThreadSafe<WebSocket>,
// Processes frame data in |current_read_buf_|.
void ProcessFrameData();
+ // Adds |len| bytes of |data| to |current_read_buf_|.
+ void AddToReadBuffer(const char* data, int len);
+
// Skips |len| bytes in |current_read_buf_|.
void SkipReadBuffer(int len);
diff --git a/net/websockets/websocket_unittest.cc b/net/websockets/websocket_unittest.cc
index cdb9aa9..907f62a 100644
--- a/net/websockets/websocket_unittest.cc
+++ b/net/websockets/websocket_unittest.cc
@@ -7,6 +7,7 @@
#include "base/task.h"
#include "net/base/completion_callback.h"
+#include "net/base/io_buffer.h"
#include "net/base/mock_host_resolver.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/socket_test_util.h"
@@ -92,41 +93,70 @@ class WebSocketEventRecorder : public net::WebSocketDelegate {
DISALLOW_COPY_AND_ASSIGN(WebSocketEventRecorder);
};
+namespace net {
+
class WebSocketTest : public PlatformTest {
+ protected:
+ void InitReadBuf(WebSocket* websocket) {
+ // Set up |current_read_buf_|.
+ websocket->current_read_buf_ = new GrowableIOBuffer();
+ }
+ void SetReadConsumed(WebSocket* websocket, int consumed) {
+ websocket->read_consumed_len_ = consumed;
+ }
+ void AddToReadBuf(WebSocket* websocket, const char* data, int len) {
+ websocket->AddToReadBuffer(data, len);
+ }
+
+ void TestProcessFrameData(WebSocket* websocket,
+ const char* expected_remaining_data,
+ int expected_remaining_len) {
+ websocket->ProcessFrameData();
+
+ const char* actual_remaining_data =
+ websocket->current_read_buf_->StartOfBuffer()
+ + websocket->read_consumed_len_;
+ int actual_remaining_len =
+ websocket->current_read_buf_->offset() - websocket->read_consumed_len_;
+
+ EXPECT_EQ(expected_remaining_len, actual_remaining_len);
+ EXPECT_TRUE(!memcmp(expected_remaining_data, actual_remaining_data,
+ expected_remaining_len));
+ }
};
TEST_F(WebSocketTest, Connect) {
- net::MockClientSocketFactory mock_socket_factory;
- net::MockRead data_reads[] = {
- net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "WebSocket-Origin: http://example.com\r\n"
- "WebSocket-Location: ws://example.com/demo\r\n"
- "WebSocket-Protocol: sample\r\n"
- "\r\n"),
+ MockClientSocketFactory mock_socket_factory;
+ MockRead data_reads[] = {
+ MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "WebSocket-Origin: http://example.com\r\n"
+ "WebSocket-Location: ws://example.com/demo\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"),
// Server doesn't close the connection after handshake.
- net::MockRead(true, net::ERR_IO_PENDING),
+ MockRead(true, ERR_IO_PENDING),
};
- net::MockWrite data_writes[] = {
- net::MockWrite("GET /demo HTTP/1.1\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Host: example.com\r\n"
- "Origin: http://example.com\r\n"
- "WebSocket-Protocol: sample\r\n"
- "\r\n"),
+ MockWrite data_writes[] = {
+ MockWrite("GET /demo HTTP/1.1\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Host: example.com\r\n"
+ "Origin: http://example.com\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"),
};
- net::StaticMockSocket data(data_reads, data_writes);
+ StaticMockSocket data(data_reads, data_writes);
mock_socket_factory.AddMockSocket(&data);
- net::WebSocket::Request* request(
- new net::WebSocket::Request(GURL("ws://example.com/demo"),
- "sample",
- "http://example.com",
- "ws://example.com/demo",
- new TestURLRequestContext()));
- request->SetHostResolver(new net::MockHostResolver());
+ WebSocket::Request* request(
+ new WebSocket::Request(GURL("ws://example.com/demo"),
+ "sample",
+ "http://example.com",
+ "ws://example.com/demo",
+ new TestURLRequestContext()));
+ request->SetHostResolver(new MockHostResolver());
request->SetClientSocketFactory(&mock_socket_factory);
TestCompletionCallback callback;
@@ -136,10 +166,10 @@ TEST_F(WebSocketTest, Connect) {
delegate->SetOnOpen(NewCallback(delegate.get(),
&WebSocketEventRecorder::DoClose));
- scoped_refptr<net::WebSocket> websocket(
- new net::WebSocket(request, delegate.get()));
+ scoped_refptr<WebSocket> websocket(
+ new WebSocket(request, delegate.get()));
- EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state());
+ EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
websocket->Connect();
callback.WaitForResult();
@@ -152,41 +182,41 @@ TEST_F(WebSocketTest, Connect) {
}
TEST_F(WebSocketTest, ServerSentData) {
- net::MockClientSocketFactory mock_socket_factory;
+ MockClientSocketFactory mock_socket_factory;
static const char kMessage[] = "Hello";
static const char kFrame[] = "\x00Hello\xff";
static const int kFrameLen = sizeof(kFrame) - 1;
- net::MockRead data_reads[] = {
- net::MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
+ MockRead data_reads[] = {
+ MockRead("HTTP/1.1 101 Web Socket Protocol\r\n"
"Upgrade: WebSocket\r\n"
"Connection: Upgrade\r\n"
"WebSocket-Origin: http://example.com\r\n"
"WebSocket-Location: ws://example.com/demo\r\n"
"WebSocket-Protocol: sample\r\n"
"\r\n"),
- net::MockRead(true, kFrame, kFrameLen),
+ MockRead(true, kFrame, kFrameLen),
// Server doesn't close the connection after handshake.
- net::MockRead(true, net::ERR_IO_PENDING),
+ MockRead(true, ERR_IO_PENDING),
};
- net::MockWrite data_writes[] = {
- net::MockWrite("GET /demo HTTP/1.1\r\n"
- "Upgrade: WebSocket\r\n"
- "Connection: Upgrade\r\n"
- "Host: example.com\r\n"
- "Origin: http://example.com\r\n"
- "WebSocket-Protocol: sample\r\n"
- "\r\n"),
+ MockWrite data_writes[] = {
+ MockWrite("GET /demo HTTP/1.1\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Host: example.com\r\n"
+ "Origin: http://example.com\r\n"
+ "WebSocket-Protocol: sample\r\n"
+ "\r\n"),
};
- net::StaticMockSocket data(data_reads, data_writes);
+ StaticMockSocket data(data_reads, data_writes);
mock_socket_factory.AddMockSocket(&data);
- net::WebSocket::Request* request(
- new net::WebSocket::Request(GURL("ws://example.com/demo"),
- "sample",
- "http://example.com",
- "ws://example.com/demo",
- new TestURLRequestContext()));
- request->SetHostResolver(new net::MockHostResolver());
+ WebSocket::Request* request(
+ new WebSocket::Request(GURL("ws://example.com/demo"),
+ "sample",
+ "http://example.com",
+ "ws://example.com/demo",
+ new TestURLRequestContext()));
+ request->SetHostResolver(new MockHostResolver());
request->SetClientSocketFactory(&mock_socket_factory);
TestCompletionCallback callback;
@@ -196,10 +226,10 @@ TEST_F(WebSocketTest, ServerSentData) {
delegate->SetOnMessage(NewCallback(delegate.get(),
&WebSocketEventRecorder::DoClose));
- scoped_refptr<net::WebSocket> websocket(
- new net::WebSocket(request, delegate.get()));
+ scoped_refptr<WebSocket> websocket(
+ new WebSocket(request, delegate.get()));
- EXPECT_EQ(net::WebSocket::INITIALIZED, websocket->ready_state());
+ EXPECT_EQ(WebSocket::INITIALIZED, websocket->ready_state());
websocket->Connect();
callback.WaitForResult();
@@ -212,3 +242,84 @@ TEST_F(WebSocketTest, ServerSentData) {
EXPECT_EQ(kMessage, events[1].msg);
EXPECT_EQ(WebSocketEvent::EVENT_CLOSE, events[2].event_type);
}
+
+TEST_F(WebSocketTest, ProcessFrameDataForLengthCalculation) {
+ WebSocket::Request* request(
+ new WebSocket::Request(GURL("ws://example.com/demo"),
+ "sample",
+ "http://example.com",
+ "ws://example.com/demo",
+ new TestURLRequestContext()));
+ TestCompletionCallback callback;
+ scoped_ptr<WebSocketEventRecorder> delegate(
+ new WebSocketEventRecorder(&callback));
+
+ scoped_refptr<WebSocket> websocket(
+ new WebSocket(request, delegate.get()));
+
+ // Frame data: skip length 1 ('x'), and try to skip length 129
+ // (1 * 128 + 1) bytes after second \x81, but buffer is too short to skip.
+ static const char kTestLengthFrame[] =
+ "\x80\x81x\x80\x81\x81\x01\x00unexpected data\xFF";
+ const int kTestLengthFrameLength = sizeof(kTestLengthFrame) - 1;
+ InitReadBuf(websocket.get());
+ AddToReadBuf(websocket.get(), kTestLengthFrame, kTestLengthFrameLength);
+ SetReadConsumed(websocket.get(), 0);
+
+ static const char kExpectedRemainingFrame[] =
+ "\x80\x81\x81\x01\x00unexpected data\xFF";
+ const int kExpectedRemainingLength = sizeof(kExpectedRemainingFrame) - 1;
+ TestProcessFrameData(websocket.get(),
+ kExpectedRemainingFrame, kExpectedRemainingLength);
+ // No onmessage event expected.
+ const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+ EXPECT_EQ(0U, events.size());
+}
+
+TEST_F(WebSocketTest, ProcessFrameDataForUnterminatedString) {
+ WebSocket::Request* request(
+ new WebSocket::Request(GURL("ws://example.com/demo"),
+ "sample",
+ "http://example.com",
+ "ws://example.com/demo",
+ new TestURLRequestContext()));
+ TestCompletionCallback callback;
+ scoped_ptr<WebSocketEventRecorder> delegate(
+ new WebSocketEventRecorder(&callback));
+
+ scoped_refptr<WebSocket> websocket(
+ new WebSocket(request, delegate.get()));
+
+ static const char kTestUnterminatedFrame[] =
+ "\x00unterminated frame";
+ const int kTestUnterminatedFrameLength = sizeof(kTestUnterminatedFrame) - 1;
+ InitReadBuf(websocket.get());
+ AddToReadBuf(websocket.get(), kTestUnterminatedFrame,
+ kTestUnterminatedFrameLength);
+ SetReadConsumed(websocket.get(), 0);
+ TestProcessFrameData(websocket.get(),
+ kTestUnterminatedFrame, kTestUnterminatedFrameLength);
+ {
+ // No onmessage event expected.
+ const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+ EXPECT_EQ(0U, events.size());
+ }
+
+ static const char kTestTerminateFrame[] = " is terminated in next read\xff";
+ const int kTestTerminateFrameLength = sizeof(kTestTerminateFrame) - 1;
+ AddToReadBuf(websocket.get(), kTestTerminateFrame,
+ kTestTerminateFrameLength);
+ TestProcessFrameData(websocket.get(), "", 0);
+
+ static const char kExpectedMsg[] =
+ "unterminated frame is terminated in next read";
+ {
+ const std::vector<WebSocketEvent>& events = delegate->GetSeenEvents();
+ EXPECT_EQ(1U, events.size());
+
+ EXPECT_EQ(WebSocketEvent::EVENT_MESSAGE, events[0].event_type);
+ EXPECT_EQ(kExpectedMsg, events[0].msg);
+ }
+}
+
+} // namespace net