diff options
Diffstat (limited to 'remoting/jingle_glue/ssl_socket_adapter.cc')
-rw-r--r-- | remoting/jingle_glue/ssl_socket_adapter.cc | 184 |
1 files changed, 132 insertions, 52 deletions
diff --git a/remoting/jingle_glue/ssl_socket_adapter.cc b/remoting/jingle_glue/ssl_socket_adapter.cc index 9e5ef86..f070c05 100644 --- a/remoting/jingle_glue/ssl_socket_adapter.cc +++ b/remoting/jingle_glue/ssl_socket_adapter.cc @@ -27,9 +27,8 @@ SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) ignore_bad_cert_(false), cert_verifier_(net::CertVerifier::CreateDefault()), ssl_state_(SSLSTATE_NONE), - read_state_(IOSTATE_NONE), - write_state_(IOSTATE_NONE), - data_transferred_(0) { + read_pending_(false), + write_pending_(false) { transport_socket_ = new TransportSocket(socket, this); } @@ -83,64 +82,102 @@ int SSLSocketAdapter::BeginSSL() { } int SSLSocketAdapter::Send(const void* buf, size_t len) { - if (ssl_state_ != SSLSTATE_CONNECTED) { + if (ssl_state_ == SSLSTATE_ERROR) { + SetError(EINVAL); + return -1; + } + + if (ssl_state_ == SSLSTATE_NONE) { + // Propagate the call to underlying socket if SSL is not connected + // yet (connection is not encrypted until StartSSL() is called). return AsyncSocketAdapter::Send(buf, len); - } else { - scoped_refptr<net::IOBuffer> transport_buf(new net::IOBuffer(len)); - memcpy(transport_buf->data(), buf, len); + } - int result = ssl_socket_->Write(transport_buf, len, - net::CompletionCallback()); - if (result == net::ERR_IO_PENDING) { - SetError(EWOULDBLOCK); - } - transport_buf = NULL; - return result; + if (write_pending_) { + SetError(EWOULDBLOCK); + return -1; } + + write_buffer_ = new net::DrainableIOBuffer(new net::IOBuffer(len), len); + memcpy(write_buffer_->data(), buf, len); + + DoWrite(); + + return len; } int SSLSocketAdapter::Recv(void* buf, size_t len) { switch (ssl_state_) { - case SSLSTATE_NONE: + case SSLSTATE_NONE: { return AsyncSocketAdapter::Recv(buf, len); + } - case SSLSTATE_WAIT: + case SSLSTATE_WAIT: { SetError(EWOULDBLOCK); return -1; + } - case SSLSTATE_CONNECTED: - switch (read_state_) { - case IOSTATE_NONE: { - transport_buf_ = new net::IOBuffer(len); - int result = ssl_socket_->Read( - transport_buf_, len, - base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); - if (result >= 0) { - memcpy(buf, transport_buf_->data(), len); - } - - if (result == net::ERR_IO_PENDING) { - read_state_ = IOSTATE_PENDING; - SetError(EWOULDBLOCK); - } else { - if (result < 0) { - SetError(result); - VLOG(1) << "Socket error " << result; - } - transport_buf_ = NULL; - } - return result; - } - case IOSTATE_PENDING: + case SSLSTATE_CONNECTED: { + if (read_pending_) { + SetError(EWOULDBLOCK); + return -1; + } + + int bytes_read = 0; + + // Process any data we have left from the previous read. + if (read_buffer_) { + int size = std::min(read_buffer_->RemainingCapacity(), + static_cast<int>(len)); + memcpy(buf, read_buffer_->data(), size); + read_buffer_->set_offset(read_buffer_->offset() + size); + if (!read_buffer_->RemainingCapacity()) + read_buffer_ = NULL; + + if (size == static_cast<int>(len)) + return size; + + // If we didn't fill the caller's buffer then dispatch a new + // Read() in case there's more data ready. + buf = reinterpret_cast<char*>(buf) + size; + len -= size; + bytes_read = size; + DCHECK(!read_buffer_); + } + + // Dispatch a Read() request to the SSL layer. + read_buffer_ = new net::GrowableIOBuffer(); + read_buffer_->SetCapacity(len); + int result = ssl_socket_->Read( + read_buffer_, len, + base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); + if (result >= 0) + memcpy(buf, read_buffer_->data(), len); + + if (result == net::ERR_IO_PENDING) { + read_pending_ = true; + if (bytes_read) { + return bytes_read; + } else { SetError(EWOULDBLOCK); return -1; + } + } - case IOSTATE_COMPLETE: - memcpy(buf, transport_buf_->data(), len); - transport_buf_ = NULL; - read_state_ = IOSTATE_NONE; - return data_transferred_; + if (result < 0) { + SetError(EINVAL); + ssl_state_ = SSLSTATE_ERROR; + LOG(ERROR) << "Error reading from SSL socket " << result; + return -1; } + read_buffer_ = NULL; + return result + bytes_read; + } + + case SSLSTATE_ERROR: { + SetError(EINVAL); + return -1; + } } NOTREACHED(); @@ -157,19 +194,62 @@ void SSLSocketAdapter::OnConnected(int result) { } void SSLSocketAdapter::OnRead(int result) { - DCHECK(read_state_ == IOSTATE_PENDING); - read_state_ = IOSTATE_COMPLETE; - data_transferred_ = result; + DCHECK(read_pending_); + read_pending_ = false; + if (result > 0) { + DCHECK_GE(read_buffer_->capacity(), result); + read_buffer_->SetCapacity(result); + } else { + if (result < 0) + ssl_state_ = SSLSTATE_ERROR; + } AsyncSocketAdapter::OnReadEvent(this); } -void SSLSocketAdapter::OnWrite(int result) { - DCHECK(write_state_ == IOSTATE_PENDING); - write_state_ = IOSTATE_COMPLETE; - data_transferred_ = result; +void SSLSocketAdapter::OnWritten(int result) { + DCHECK(write_pending_); + write_pending_ = false; + if (result >= 0) { + write_buffer_->DidConsume(result); + if (!write_buffer_->BytesRemaining()) { + write_buffer_ = NULL; + } else { + DoWrite(); + } + } else { + ssl_state_ = SSLSTATE_ERROR; + } AsyncSocketAdapter::OnWriteEvent(this); } +void SSLSocketAdapter::DoWrite() { + DCHECK_GT(write_buffer_->BytesRemaining(), 0); + DCHECK(!write_pending_); + + while (true) { + int result = ssl_socket_->Write( + write_buffer_, write_buffer_->BytesRemaining(), + base::Bind(&SSLSocketAdapter::OnWritten, base::Unretained(this))); + + if (result > 0) { + write_buffer_->DidConsume(result); + if (!write_buffer_->BytesRemaining()) { + write_buffer_ = NULL; + return; + } + continue; + } + + if (result == net::ERR_IO_PENDING) { + write_pending_ = true; + } else { + SetError(EINVAL); + ssl_state_ = SSLSTATE_ERROR; + } + return; + } +} + void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) { if (ssl_state_ != SSLSTATE_WAIT) { AsyncSocketAdapter::OnConnectEvent(socket); |