summaryrefslogtreecommitdiffstats
path: root/remoting/jingle_glue/ssl_socket_adapter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'remoting/jingle_glue/ssl_socket_adapter.cc')
-rw-r--r--remoting/jingle_glue/ssl_socket_adapter.cc184
1 files changed, 132 insertions, 52 deletions
diff --git a/remoting/jingle_glue/ssl_socket_adapter.cc b/remoting/jingle_glue/ssl_socket_adapter.cc
index 9e5ef86..f070c05 100644
--- a/remoting/jingle_glue/ssl_socket_adapter.cc
+++ b/remoting/jingle_glue/ssl_socket_adapter.cc
@@ -27,9 +27,8 @@ SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket)
ignore_bad_cert_(false),
cert_verifier_(net::CertVerifier::CreateDefault()),
ssl_state_(SSLSTATE_NONE),
- read_state_(IOSTATE_NONE),
- write_state_(IOSTATE_NONE),
- data_transferred_(0) {
+ read_pending_(false),
+ write_pending_(false) {
transport_socket_ = new TransportSocket(socket, this);
}
@@ -83,64 +82,102 @@ int SSLSocketAdapter::BeginSSL() {
}
int SSLSocketAdapter::Send(const void* buf, size_t len) {
- if (ssl_state_ != SSLSTATE_CONNECTED) {
+ if (ssl_state_ == SSLSTATE_ERROR) {
+ SetError(EINVAL);
+ return -1;
+ }
+
+ if (ssl_state_ == SSLSTATE_NONE) {
+ // Propagate the call to underlying socket if SSL is not connected
+ // yet (connection is not encrypted until StartSSL() is called).
return AsyncSocketAdapter::Send(buf, len);
- } else {
- scoped_refptr<net::IOBuffer> transport_buf(new net::IOBuffer(len));
- memcpy(transport_buf->data(), buf, len);
+ }
- int result = ssl_socket_->Write(transport_buf, len,
- net::CompletionCallback());
- if (result == net::ERR_IO_PENDING) {
- SetError(EWOULDBLOCK);
- }
- transport_buf = NULL;
- return result;
+ if (write_pending_) {
+ SetError(EWOULDBLOCK);
+ return -1;
}
+
+ write_buffer_ = new net::DrainableIOBuffer(new net::IOBuffer(len), len);
+ memcpy(write_buffer_->data(), buf, len);
+
+ DoWrite();
+
+ return len;
}
int SSLSocketAdapter::Recv(void* buf, size_t len) {
switch (ssl_state_) {
- case SSLSTATE_NONE:
+ case SSLSTATE_NONE: {
return AsyncSocketAdapter::Recv(buf, len);
+ }
- case SSLSTATE_WAIT:
+ case SSLSTATE_WAIT: {
SetError(EWOULDBLOCK);
return -1;
+ }
- case SSLSTATE_CONNECTED:
- switch (read_state_) {
- case IOSTATE_NONE: {
- transport_buf_ = new net::IOBuffer(len);
- int result = ssl_socket_->Read(
- transport_buf_, len,
- base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this)));
- if (result >= 0) {
- memcpy(buf, transport_buf_->data(), len);
- }
-
- if (result == net::ERR_IO_PENDING) {
- read_state_ = IOSTATE_PENDING;
- SetError(EWOULDBLOCK);
- } else {
- if (result < 0) {
- SetError(result);
- VLOG(1) << "Socket error " << result;
- }
- transport_buf_ = NULL;
- }
- return result;
- }
- case IOSTATE_PENDING:
+ case SSLSTATE_CONNECTED: {
+ if (read_pending_) {
+ SetError(EWOULDBLOCK);
+ return -1;
+ }
+
+ int bytes_read = 0;
+
+ // Process any data we have left from the previous read.
+ if (read_buffer_) {
+ int size = std::min(read_buffer_->RemainingCapacity(),
+ static_cast<int>(len));
+ memcpy(buf, read_buffer_->data(), size);
+ read_buffer_->set_offset(read_buffer_->offset() + size);
+ if (!read_buffer_->RemainingCapacity())
+ read_buffer_ = NULL;
+
+ if (size == static_cast<int>(len))
+ return size;
+
+ // If we didn't fill the caller's buffer then dispatch a new
+ // Read() in case there's more data ready.
+ buf = reinterpret_cast<char*>(buf) + size;
+ len -= size;
+ bytes_read = size;
+ DCHECK(!read_buffer_);
+ }
+
+ // Dispatch a Read() request to the SSL layer.
+ read_buffer_ = new net::GrowableIOBuffer();
+ read_buffer_->SetCapacity(len);
+ int result = ssl_socket_->Read(
+ read_buffer_, len,
+ base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this)));
+ if (result >= 0)
+ memcpy(buf, read_buffer_->data(), len);
+
+ if (result == net::ERR_IO_PENDING) {
+ read_pending_ = true;
+ if (bytes_read) {
+ return bytes_read;
+ } else {
SetError(EWOULDBLOCK);
return -1;
+ }
+ }
- case IOSTATE_COMPLETE:
- memcpy(buf, transport_buf_->data(), len);
- transport_buf_ = NULL;
- read_state_ = IOSTATE_NONE;
- return data_transferred_;
+ if (result < 0) {
+ SetError(EINVAL);
+ ssl_state_ = SSLSTATE_ERROR;
+ LOG(ERROR) << "Error reading from SSL socket " << result;
+ return -1;
}
+ read_buffer_ = NULL;
+ return result + bytes_read;
+ }
+
+ case SSLSTATE_ERROR: {
+ SetError(EINVAL);
+ return -1;
+ }
}
NOTREACHED();
@@ -157,19 +194,62 @@ void SSLSocketAdapter::OnConnected(int result) {
}
void SSLSocketAdapter::OnRead(int result) {
- DCHECK(read_state_ == IOSTATE_PENDING);
- read_state_ = IOSTATE_COMPLETE;
- data_transferred_ = result;
+ DCHECK(read_pending_);
+ read_pending_ = false;
+ if (result > 0) {
+ DCHECK_GE(read_buffer_->capacity(), result);
+ read_buffer_->SetCapacity(result);
+ } else {
+ if (result < 0)
+ ssl_state_ = SSLSTATE_ERROR;
+ }
AsyncSocketAdapter::OnReadEvent(this);
}
-void SSLSocketAdapter::OnWrite(int result) {
- DCHECK(write_state_ == IOSTATE_PENDING);
- write_state_ = IOSTATE_COMPLETE;
- data_transferred_ = result;
+void SSLSocketAdapter::OnWritten(int result) {
+ DCHECK(write_pending_);
+ write_pending_ = false;
+ if (result >= 0) {
+ write_buffer_->DidConsume(result);
+ if (!write_buffer_->BytesRemaining()) {
+ write_buffer_ = NULL;
+ } else {
+ DoWrite();
+ }
+ } else {
+ ssl_state_ = SSLSTATE_ERROR;
+ }
AsyncSocketAdapter::OnWriteEvent(this);
}
+void SSLSocketAdapter::DoWrite() {
+ DCHECK_GT(write_buffer_->BytesRemaining(), 0);
+ DCHECK(!write_pending_);
+
+ while (true) {
+ int result = ssl_socket_->Write(
+ write_buffer_, write_buffer_->BytesRemaining(),
+ base::Bind(&SSLSocketAdapter::OnWritten, base::Unretained(this)));
+
+ if (result > 0) {
+ write_buffer_->DidConsume(result);
+ if (!write_buffer_->BytesRemaining()) {
+ write_buffer_ = NULL;
+ return;
+ }
+ continue;
+ }
+
+ if (result == net::ERR_IO_PENDING) {
+ write_pending_ = true;
+ } else {
+ SetError(EINVAL);
+ ssl_state_ = SSLSTATE_ERROR;
+ }
+ return;
+ }
+}
+
void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) {
if (ssl_state_ != SSLSTATE_WAIT) {
AsyncSocketAdapter::OnConnectEvent(socket);