diff options
Diffstat (limited to 'net')
50 files changed, 2838 insertions, 2801 deletions
diff --git a/net/base/upload_data.cc b/net/base/upload_data.cc index f148cac..1dd8bbc 100644 --- a/net/base/upload_data.cc +++ b/net/base/upload_data.cc @@ -28,6 +28,18 @@ UploadData::Element::~Element() { delete file_stream_; } +void UploadData::Element::SetToChunk(const char* bytes, int bytes_len) { + std::string chunk_length = StringPrintf("%X\r\n", bytes_len); + bytes_.clear(); + bytes_.insert(bytes_.end(), chunk_length.data(), + chunk_length.data() + chunk_length.length()); + bytes_.insert(bytes_.end(), bytes, bytes + bytes_len); + const char* crlf = "\r\n"; + bytes_.insert(bytes_.end(), crlf, crlf + 2); + type_ = TYPE_CHUNK; + is_last_chunk_ = (bytes_len == 0); +} + uint64 UploadData::Element::GetContentLength() { if (override_content_length_ || content_length_computed_) return content_length_; @@ -67,18 +79,6 @@ uint64 UploadData::Element::GetContentLength() { return content_length_; } -void UploadData::Element::SetToChunk(const char* bytes, int bytes_len) { - std::string chunk_length = StringPrintf("%X\r\n", bytes_len); - bytes_.clear(); - bytes_.insert(bytes_.end(), chunk_length.data(), - chunk_length.data() + chunk_length.length()); - bytes_.insert(bytes_.end(), bytes, bytes + bytes_len); - const char* crlf = "\r\n"; - bytes_.insert(bytes_.end(), crlf, crlf + 2); - type_ = TYPE_CHUNK; - is_last_chunk_ = (bytes_len == 0); -} - FileStream* UploadData::Element::NewFileStreamForReading() { // In common usage GetContentLength() will call this first and store the // result into |file_| and a subsequent call (from UploadDataStream) will diff --git a/net/base/upload_data_stream.cc b/net/base/upload_data_stream.cc index 1f77f06..9f7bdbb 100644 --- a/net/base/upload_data_stream.cc +++ b/net/base/upload_data_stream.cc @@ -12,6 +12,9 @@ namespace net { +UploadDataStream::~UploadDataStream() { +} + UploadDataStream* UploadDataStream::Create(UploadData* data, int* error_code) { scoped_ptr<UploadDataStream> stream(new UploadDataStream(data)); int rv = stream->FillBuf(); @@ -50,9 +53,6 @@ UploadDataStream::UploadDataStream(UploadData* data) eof_(false) { } -UploadDataStream::~UploadDataStream() { -} - int UploadDataStream::FillBuf() { std::vector<UploadData::Element>& elements = *data_->elements(); diff --git a/net/base/x509_certificate.cc b/net/base/x509_certificate.cc index b705790..6469537 100644 --- a/net/base/x509_certificate.cc +++ b/net/base/x509_certificate.cc @@ -121,6 +121,19 @@ bool X509Certificate::LessThan::operator()(X509Certificate* lhs, return fingerprint_functor(lhs->fingerprint_, rhs->fingerprint_); } +X509Certificate::X509Certificate(const std::string& subject, + const std::string& issuer, + base::Time start_date, + base::Time expiration_date) + : subject_(subject), + issuer_(issuer), + valid_start_(start_date), + valid_expiry_(expiration_date), + cert_handle_(NULL), + source_(SOURCE_UNUSED) { + memset(fingerprint_.data, 0, sizeof(fingerprint_.data)); +} + // static X509Certificate* X509Certificate::CreateFromHandle( OSCertHandle cert_handle, @@ -283,40 +296,6 @@ CertificateList X509Certificate::CreateCertificateListFromBytes( return results; } -X509Certificate::X509Certificate(OSCertHandle cert_handle, - Source source, - const OSCertHandles& intermediates) - : cert_handle_(DupOSCertHandle(cert_handle)), - source_(source) { - // Copy/retain the intermediate cert handles. - for (size_t i = 0; i < intermediates.size(); ++i) - intermediate_ca_certs_.push_back(DupOSCertHandle(intermediates[i])); - // Platform-specific initialization. - Initialize(); -} - -X509Certificate::X509Certificate(const std::string& subject, - const std::string& issuer, - base::Time start_date, - base::Time expiration_date) - : subject_(subject), - issuer_(issuer), - valid_start_(start_date), - valid_expiry_(expiration_date), - cert_handle_(NULL), - source_(SOURCE_UNUSED) { - memset(fingerprint_.data, 0, sizeof(fingerprint_.data)); -} - -X509Certificate::~X509Certificate() { - // We might not be in the cache, but it is safe to remove ourselves anyway. - g_x509_certificate_cache.Get().Remove(this); - if (cert_handle_) - FreeOSCertHandle(cert_handle_); - for (size_t i = 0; i < intermediate_ca_certs_.size(); ++i) - FreeOSCertHandle(intermediate_ca_certs_[i]); -} - bool X509Certificate::HasExpired() const { return base::Time::Now() > valid_expiry(); } @@ -345,4 +324,25 @@ bool X509Certificate::HasIntermediateCertificates(const OSCertHandles& certs) { return true; } +X509Certificate::X509Certificate(OSCertHandle cert_handle, + Source source, + const OSCertHandles& intermediates) + : cert_handle_(DupOSCertHandle(cert_handle)), + source_(source) { + // Copy/retain the intermediate cert handles. + for (size_t i = 0; i < intermediates.size(); ++i) + intermediate_ca_certs_.push_back(DupOSCertHandle(intermediates[i])); + // Platform-specific initialization. + Initialize(); +} + +X509Certificate::~X509Certificate() { + // We might not be in the cache, but it is safe to remove ourselves anyway. + g_x509_certificate_cache.Get().Remove(this); + if (cert_handle_) + FreeOSCertHandle(cert_handle_); + for (size_t i = 0; i < intermediate_ca_certs_.size(); ++i) + FreeOSCertHandle(intermediate_ca_certs_[i]); +} + } // namespace net diff --git a/net/disk_cache/backend_impl.cc b/net/disk_cache/backend_impl.cc index 89890e9..0a6ad778 100644 --- a/net/disk_cache/backend_impl.cc +++ b/net/disk_cache/backend_impl.cc @@ -344,34 +344,6 @@ int PreferedCacheSize(int64 available) { // ------------------------------------------------------------------------ -// If the initialization of the cache fails, and force is true, we will discard -// the whole cache and create a new one. In order to process a potentially large -// number of files, we'll rename the cache folder to old_ + original_name + -// number, (located on the same parent folder), and spawn a worker thread to -// delete all the files on all the stale cache folders. The whole process can -// still fail if we are not able to rename the cache folder (for instance due to -// a sharing violation), and in that case a cache for this profile (on the -// desired path) cannot be created. -// -// Static. -int BackendImpl::CreateBackend(const FilePath& full_path, bool force, - int max_bytes, net::CacheType type, - uint32 flags, base::MessageLoopProxy* thread, - net::NetLog* net_log, Backend** backend, - CompletionCallback* callback) { - DCHECK(callback); - CacheCreator* creator = new CacheCreator(full_path, force, max_bytes, type, - flags, thread, net_log, backend, - callback); - // This object will self-destroy when finished. - return creator->Run(); -} - -int BackendImpl::Init(CompletionCallback* callback) { - background_queue_.Init(callback); - return net::ERR_IO_PENDING; -} - BackendImpl::BackendImpl(const FilePath& path, base::MessageLoopProxy* cache_thread, net::NetLog* net_log) @@ -436,7 +408,33 @@ BackendImpl::~BackendImpl() { } } -// ------------------------------------------------------------------------ +// If the initialization of the cache fails, and force is true, we will discard +// the whole cache and create a new one. In order to process a potentially large +// number of files, we'll rename the cache folder to old_ + original_name + +// number, (located on the same parent folder), and spawn a worker thread to +// delete all the files on all the stale cache folders. The whole process can +// still fail if we are not able to rename the cache folder (for instance due to +// a sharing violation), and in that case a cache for this profile (on the +// desired path) cannot be created. +// +// Static. +int BackendImpl::CreateBackend(const FilePath& full_path, bool force, + int max_bytes, net::CacheType type, + uint32 flags, base::MessageLoopProxy* thread, + net::NetLog* net_log, Backend** backend, + CompletionCallback* callback) { + DCHECK(callback); + CacheCreator* creator = new CacheCreator(full_path, force, max_bytes, type, + flags, thread, net_log, backend, + callback); + // This object will self-destroy when finished. + return creator->Run(); +} + +int BackendImpl::Init(CompletionCallback* callback) { + background_queue_.Init(callback); + return net::ERR_IO_PENDING; +} int BackendImpl::SyncInit() { DCHECK(!init_); diff --git a/net/disk_cache/file_posix.cc b/net/disk_cache/file_posix.cc index 01dafd3..740d108 100644 --- a/net/disk_cache/file_posix.cc +++ b/net/disk_cache/file_posix.cc @@ -189,11 +189,6 @@ bool File::Init(const FilePath& name) { return true; } -File::~File() { - if (platform_file_) - close(platform_file_); -} - base::PlatformFile File::platform_file() const { return platform_file_; } @@ -255,19 +250,6 @@ bool File::Write(const void* buffer, size_t buffer_len, size_t offset, return AsyncWrite(buffer, buffer_len, offset, callback, completed); } -bool File::AsyncWrite(const void* buffer, size_t buffer_len, size_t offset, - FileIOCallback* callback, bool* completed) { - DCHECK(init_); - if (buffer_len > ULONG_MAX || offset > ULONG_MAX) - return false; - - GetFileInFlightIO()->PostWrite(this, buffer, buffer_len, offset, callback); - - if (completed) - *completed = false; - return true; -} - bool File::SetLength(size_t length) { DCHECK(init_); if (length > ULONG_MAX) @@ -290,4 +272,22 @@ void File::WaitForPendingIO(int* num_pending_io) { DeleteFileInFlightIO(); } +File::~File() { + if (platform_file_) + close(platform_file_); +} + +bool File::AsyncWrite(const void* buffer, size_t buffer_len, size_t offset, + FileIOCallback* callback, bool* completed) { + DCHECK(init_); + if (buffer_len > ULONG_MAX || offset > ULONG_MAX) + return false; + + GetFileInFlightIO()->PostWrite(this, buffer, buffer_len, offset, callback); + + if (completed) + *completed = false; + return true; +} + } // namespace disk_cache diff --git a/net/disk_cache/mapped_file_posix.cc b/net/disk_cache/mapped_file_posix.cc index f9a361b..9abfa5cd 100644 --- a/net/disk_cache/mapped_file_posix.cc +++ b/net/disk_cache/mapped_file_posix.cc @@ -32,16 +32,6 @@ void* MappedFile::Init(const FilePath& name, size_t size) { return buffer_; } -MappedFile::~MappedFile() { - if (!init_) - return; - - if (buffer_) { - int ret = munmap(buffer_, view_size_); - DCHECK(0 == ret); - } -} - bool MappedFile::Load(const FileBlock* block) { size_t offset = block->offset() + view_size_; return Read(block->buffer(), block->size(), offset); @@ -52,4 +42,14 @@ bool MappedFile::Store(const FileBlock* block) { return Write(block->buffer(), block->size(), offset); } +MappedFile::~MappedFile() { + if (!init_) + return; + + if (buffer_) { + int ret = munmap(buffer_, view_size_); + DCHECK(0 == ret); + } +} + } // namespace disk_cache diff --git a/net/disk_cache/rankings.cc b/net/disk_cache/rankings.cc index 801d387..b10dac6 100644 --- a/net/disk_cache/rankings.cc +++ b/net/disk_cache/rankings.cc @@ -228,58 +228,6 @@ void Rankings::Reset() { control_data_ = NULL; } -bool Rankings::GetRanking(CacheRankingsBlock* rankings) { - if (!rankings->address().is_initialized()) - return false; - - TimeTicks start = TimeTicks::Now(); - if (!rankings->Load()) - return false; - - if (!SanityCheck(rankings, true)) { - backend_->CriticalError(ERR_INVALID_LINKS); - return false; - } - - backend_->OnEvent(Stats::OPEN_RANKINGS); - - // "dummy" is the old "pointer" value, so it has to be 0. - if (!rankings->Data()->dirty && !rankings->Data()->dummy) - return true; - - EntryImpl* entry = backend_->GetOpenEntry(rankings); - if (backend_->GetCurrentEntryId() != rankings->Data()->dirty || !entry) { - // We cannot trust this entry, but we cannot initiate a cleanup from this - // point (we may be in the middle of a cleanup already). Just get rid of - // the invalid pointer and continue; the entry will be deleted when detected - // from a regular open/create path. - rankings->Data()->dummy = 0; - rankings->Data()->dirty = backend_->GetCurrentEntryId() - 1; - if (!rankings->Data()->dirty) - rankings->Data()->dirty--; - return true; - } - - // Note that we should not leave this module without deleting rankings first. - rankings->SetData(entry->rankings()->Data()); - - CACHE_UMA(AGE_MS, "GetRankings", 0, start); - return true; -} - -void Rankings::ConvertToLongLived(CacheRankingsBlock* rankings) { - if (rankings->own_data()) - return; - - // We cannot return a shared node because we are not keeping a reference - // to the entry that owns the buffer. Make this node a copy of the one that - // we have, and let the iterator logic update it when the entry changes. - CacheRankingsBlock temp(NULL, Addr(0)); - *temp.Data() = *rankings->Data(); - rankings->StopSharingData(); - *rankings->Data() = *temp.Data(); -} - void Rankings::Insert(CacheRankingsBlock* node, bool modified, List list) { Trace("Insert 0x%x l %d", node->address().value(), list); DCHECK(node->HasData()); @@ -443,116 +391,6 @@ void Rankings::UpdateRank(CacheRankingsBlock* node, bool modified, List list) { CACHE_UMA(AGE_MS, "UpdateRank", 0, start); } -void Rankings::CompleteTransaction() { - Addr node_addr(static_cast<CacheAddr>(control_data_->transaction)); - if (!node_addr.is_initialized() || node_addr.is_separate_file()) { - NOTREACHED(); - LOG(ERROR) << "Invalid rankings info."; - return; - } - - Trace("CompleteTransaction 0x%x", node_addr.value()); - - CacheRankingsBlock node(backend_->File(node_addr), node_addr); - if (!node.Load()) - return; - - node.Data()->dummy = 0; - node.Store(); - - Addr& my_head = heads_[control_data_->operation_list]; - Addr& my_tail = tails_[control_data_->operation_list]; - - // We want to leave the node inside the list. The entry must me marked as - // dirty, and will be removed later. Otherwise, we'll get assertions when - // attempting to remove the dirty entry. - if (INSERT == control_data_->operation) { - Trace("FinishInsert h:0x%x t:0x%x", my_head.value(), my_tail.value()); - FinishInsert(&node); - } else if (REMOVE == control_data_->operation) { - Trace("RevertRemove h:0x%x t:0x%x", my_head.value(), my_tail.value()); - RevertRemove(&node); - } else { - NOTREACHED(); - LOG(ERROR) << "Invalid operation to recover."; - } -} - -void Rankings::FinishInsert(CacheRankingsBlock* node) { - control_data_->transaction = 0; - control_data_->operation = 0; - Addr& my_head = heads_[control_data_->operation_list]; - Addr& my_tail = tails_[control_data_->operation_list]; - if (my_head.value() != node->address().value()) { - if (my_tail.value() == node->address().value()) { - // This part will be skipped by the logic of Insert. - node->Data()->next = my_tail.value(); - } - - Insert(node, true, static_cast<List>(control_data_->operation_list)); - } - - // Tell the backend about this entry. - backend_->RecoveredEntry(node); -} - -void Rankings::RevertRemove(CacheRankingsBlock* node) { - Addr next_addr(node->Data()->next); - Addr prev_addr(node->Data()->prev); - if (!next_addr.is_initialized() || !prev_addr.is_initialized()) { - // The operation actually finished. Nothing to do. - control_data_->transaction = 0; - return; - } - if (next_addr.is_separate_file() || prev_addr.is_separate_file()) { - NOTREACHED(); - LOG(WARNING) << "Invalid rankings info."; - control_data_->transaction = 0; - return; - } - - CacheRankingsBlock next(backend_->File(next_addr), next_addr); - CacheRankingsBlock prev(backend_->File(prev_addr), prev_addr); - if (!next.Load() || !prev.Load()) - return; - - CacheAddr node_value = node->address().value(); - DCHECK(prev.Data()->next == node_value || - prev.Data()->next == prev_addr.value() || - prev.Data()->next == next.address().value()); - DCHECK(next.Data()->prev == node_value || - next.Data()->prev == next_addr.value() || - next.Data()->prev == prev.address().value()); - - if (node_value != prev_addr.value()) - prev.Data()->next = node_value; - if (node_value != next_addr.value()) - next.Data()->prev = node_value; - - List my_list = static_cast<List>(control_data_->operation_list); - Addr& my_head = heads_[my_list]; - Addr& my_tail = tails_[my_list]; - if (!my_head.is_initialized() || !my_tail.is_initialized()) { - my_head.set_value(node_value); - my_tail.set_value(node_value); - WriteHead(my_list); - WriteTail(my_list); - } else if (my_head.value() == next.address().value()) { - my_head.set_value(node_value); - prev.Data()->next = next.address().value(); - WriteHead(my_list); - } else if (my_tail.value() == prev.address().value()) { - my_tail.set_value(node_value); - next.Data()->prev = prev.address().value(); - WriteTail(my_list); - } - - next.Store(); - prev.Store(); - control_data_->transaction = 0; - control_data_->operation = 0; -} - CacheRankingsBlock* Rankings::GetNext(CacheRankingsBlock* node, List list) { ScopedRankingsBlock next(this); if (!node) { @@ -691,6 +529,168 @@ void Rankings::WriteTail(List list) { control_data_->tails[list] = tails_[list].value(); } +bool Rankings::GetRanking(CacheRankingsBlock* rankings) { + if (!rankings->address().is_initialized()) + return false; + + TimeTicks start = TimeTicks::Now(); + if (!rankings->Load()) + return false; + + if (!SanityCheck(rankings, true)) { + backend_->CriticalError(ERR_INVALID_LINKS); + return false; + } + + backend_->OnEvent(Stats::OPEN_RANKINGS); + + // "dummy" is the old "pointer" value, so it has to be 0. + if (!rankings->Data()->dirty && !rankings->Data()->dummy) + return true; + + EntryImpl* entry = backend_->GetOpenEntry(rankings); + if (backend_->GetCurrentEntryId() != rankings->Data()->dirty || !entry) { + // We cannot trust this entry, but we cannot initiate a cleanup from this + // point (we may be in the middle of a cleanup already). Just get rid of + // the invalid pointer and continue; the entry will be deleted when detected + // from a regular open/create path. + rankings->Data()->dummy = 0; + rankings->Data()->dirty = backend_->GetCurrentEntryId() - 1; + if (!rankings->Data()->dirty) + rankings->Data()->dirty--; + return true; + } + + // Note that we should not leave this module without deleting rankings first. + rankings->SetData(entry->rankings()->Data()); + + CACHE_UMA(AGE_MS, "GetRankings", 0, start); + return true; +} + +void Rankings::ConvertToLongLived(CacheRankingsBlock* rankings) { + if (rankings->own_data()) + return; + + // We cannot return a shared node because we are not keeping a reference + // to the entry that owns the buffer. Make this node a copy of the one that + // we have, and let the iterator logic update it when the entry changes. + CacheRankingsBlock temp(NULL, Addr(0)); + *temp.Data() = *rankings->Data(); + rankings->StopSharingData(); + *rankings->Data() = *temp.Data(); +} + +void Rankings::CompleteTransaction() { + Addr node_addr(static_cast<CacheAddr>(control_data_->transaction)); + if (!node_addr.is_initialized() || node_addr.is_separate_file()) { + NOTREACHED(); + LOG(ERROR) << "Invalid rankings info."; + return; + } + + Trace("CompleteTransaction 0x%x", node_addr.value()); + + CacheRankingsBlock node(backend_->File(node_addr), node_addr); + if (!node.Load()) + return; + + node.Data()->dummy = 0; + node.Store(); + + Addr& my_head = heads_[control_data_->operation_list]; + Addr& my_tail = tails_[control_data_->operation_list]; + + // We want to leave the node inside the list. The entry must me marked as + // dirty, and will be removed later. Otherwise, we'll get assertions when + // attempting to remove the dirty entry. + if (INSERT == control_data_->operation) { + Trace("FinishInsert h:0x%x t:0x%x", my_head.value(), my_tail.value()); + FinishInsert(&node); + } else if (REMOVE == control_data_->operation) { + Trace("RevertRemove h:0x%x t:0x%x", my_head.value(), my_tail.value()); + RevertRemove(&node); + } else { + NOTREACHED(); + LOG(ERROR) << "Invalid operation to recover."; + } +} + +void Rankings::FinishInsert(CacheRankingsBlock* node) { + control_data_->transaction = 0; + control_data_->operation = 0; + Addr& my_head = heads_[control_data_->operation_list]; + Addr& my_tail = tails_[control_data_->operation_list]; + if (my_head.value() != node->address().value()) { + if (my_tail.value() == node->address().value()) { + // This part will be skipped by the logic of Insert. + node->Data()->next = my_tail.value(); + } + + Insert(node, true, static_cast<List>(control_data_->operation_list)); + } + + // Tell the backend about this entry. + backend_->RecoveredEntry(node); +} + +void Rankings::RevertRemove(CacheRankingsBlock* node) { + Addr next_addr(node->Data()->next); + Addr prev_addr(node->Data()->prev); + if (!next_addr.is_initialized() || !prev_addr.is_initialized()) { + // The operation actually finished. Nothing to do. + control_data_->transaction = 0; + return; + } + if (next_addr.is_separate_file() || prev_addr.is_separate_file()) { + NOTREACHED(); + LOG(WARNING) << "Invalid rankings info."; + control_data_->transaction = 0; + return; + } + + CacheRankingsBlock next(backend_->File(next_addr), next_addr); + CacheRankingsBlock prev(backend_->File(prev_addr), prev_addr); + if (!next.Load() || !prev.Load()) + return; + + CacheAddr node_value = node->address().value(); + DCHECK(prev.Data()->next == node_value || + prev.Data()->next == prev_addr.value() || + prev.Data()->next == next.address().value()); + DCHECK(next.Data()->prev == node_value || + next.Data()->prev == next_addr.value() || + next.Data()->prev == prev.address().value()); + + if (node_value != prev_addr.value()) + prev.Data()->next = node_value; + if (node_value != next_addr.value()) + next.Data()->prev = node_value; + + List my_list = static_cast<List>(control_data_->operation_list); + Addr& my_head = heads_[my_list]; + Addr& my_tail = tails_[my_list]; + if (!my_head.is_initialized() || !my_tail.is_initialized()) { + my_head.set_value(node_value); + my_tail.set_value(node_value); + WriteHead(my_list); + WriteTail(my_list); + } else if (my_head.value() == next.address().value()) { + my_head.set_value(node_value); + prev.Data()->next = next.address().value(); + WriteHead(my_list); + } else if (my_tail.value() == prev.address().value()) { + my_tail.set_value(node_value); + next.Data()->prev = prev.address().value(); + WriteTail(my_list); + } + + next.Store(); + prev.Store(); + control_data_->transaction = 0; + control_data_->operation = 0; +} + bool Rankings::CheckEntry(CacheRankingsBlock* rankings) { if (!rankings->Data()->dummy) return true; diff --git a/net/disk_cache/stats.cc b/net/disk_cache/stats.cc index 5222112..d9a9d12 100644 --- a/net/disk_cache/stats.cc +++ b/net/disk_cache/stats.cc @@ -116,6 +116,12 @@ bool CreateStats(BackendImpl* backend, Addr* address, OnDiskStats* stats) { return StoreStats(backend, *address, stats); } +Stats::Stats() : backend_(NULL) { +} + +Stats::~Stats() { +} + bool Stats::Init(BackendImpl* backend, uint32* storage_addr) { OnDiskStats stats; Addr address(*storage_addr); @@ -153,86 +159,6 @@ bool Stats::Init(BackendImpl* backend, uint32* storage_addr) { return true; } -Stats::Stats() : backend_(NULL) { -} - -Stats::~Stats() { -} - -// The array will be filled this way: -// index size -// 0 [0, 1024) -// 1 [1024, 2048) -// 2 [2048, 4096) -// 3 [4K, 6K) -// ... -// 10 [18K, 20K) -// 11 [20K, 24K) -// 12 [24k, 28K) -// ... -// 15 [36k, 40K) -// 16 [40k, 64K) -// 17 [64K, 128K) -// 18 [128K, 256K) -// ... -// 23 [4M, 8M) -// 24 [8M, 16M) -// 25 [16M, 32M) -// 26 [32M, 64M) -// 27 [64M, ...) -int Stats::GetStatsBucket(int32 size) { - if (size < 1024) - return 0; - - // 10 slots more, until 20K. - if (size < 20 * 1024) - return size / 2048 + 1; - - // 5 slots more, from 20K to 40K. - if (size < 40 * 1024) - return (size - 20 * 1024) / 4096 + 11; - - // From this point on, use a logarithmic scale. - int result = LogBase2(size) + 1; - - COMPILE_ASSERT(kDataSizesLength > 16, update_the_scale); - if (result >= kDataSizesLength) - result = kDataSizesLength - 1; - - return result; -} - -int Stats::GetBucketRange(size_t i) const { - if (i < 2) - return static_cast<int>(1024 * i); - - if (i < 12) - return static_cast<int>(2048 * (i - 1)); - - if (i < 17) - return static_cast<int>(4096 * (i - 11)) + 20 * 1024; - - int n = 64 * 1024; - if (i > static_cast<size_t>(kDataSizesLength)) { - NOTREACHED(); - i = kDataSizesLength; - } - - i -= 17; - n <<= i; - return n; -} - -void Stats::Snapshot(StatsHistogram::StatsSamples* samples) const { - samples->GetCounts()->resize(kDataSizesLength); - for (int i = 0; i < kDataSizesLength; i++) { - int count = data_sizes_[i]; - if (count < 0) - count = 0; - samples->GetCounts()->at(i) = count; - } -} - void Stats::ModifyStorageStats(int32 old_size, int32 new_size) { // We keep a counter of the data block size on an array where each entry is // the adjusted log base 2 of the size. The first entry counts blocks of 256 @@ -286,15 +212,6 @@ int Stats::GetResurrectRatio() const { return GetRatio(RESURRECT_HIT, CREATE_HIT); } -int Stats::GetRatio(Counters hit, Counters miss) const { - int64 ratio = GetCounter(hit) * 100; - if (!ratio) - return 0; - - ratio /= (GetCounter(hit) + GetCounter(miss)); - return static_cast<int>(ratio); -} - void Stats::ResetRatios() { SetCounter(OPEN_HIT, 0); SetCounter(OPEN_MISS, 0); @@ -326,4 +243,87 @@ void Stats::Store() { StoreStats(backend_, address, &stats); } +int Stats::GetBucketRange(size_t i) const { + if (i < 2) + return static_cast<int>(1024 * i); + + if (i < 12) + return static_cast<int>(2048 * (i - 1)); + + if (i < 17) + return static_cast<int>(4096 * (i - 11)) + 20 * 1024; + + int n = 64 * 1024; + if (i > static_cast<size_t>(kDataSizesLength)) { + NOTREACHED(); + i = kDataSizesLength; + } + + i -= 17; + n <<= i; + return n; +} + +void Stats::Snapshot(StatsHistogram::StatsSamples* samples) const { + samples->GetCounts()->resize(kDataSizesLength); + for (int i = 0; i < kDataSizesLength; i++) { + int count = data_sizes_[i]; + if (count < 0) + count = 0; + samples->GetCounts()->at(i) = count; + } +} + +// The array will be filled this way: +// index size +// 0 [0, 1024) +// 1 [1024, 2048) +// 2 [2048, 4096) +// 3 [4K, 6K) +// ... +// 10 [18K, 20K) +// 11 [20K, 24K) +// 12 [24k, 28K) +// ... +// 15 [36k, 40K) +// 16 [40k, 64K) +// 17 [64K, 128K) +// 18 [128K, 256K) +// ... +// 23 [4M, 8M) +// 24 [8M, 16M) +// 25 [16M, 32M) +// 26 [32M, 64M) +// 27 [64M, ...) +int Stats::GetStatsBucket(int32 size) { + if (size < 1024) + return 0; + + // 10 slots more, until 20K. + if (size < 20 * 1024) + return size / 2048 + 1; + + // 5 slots more, from 20K to 40K. + if (size < 40 * 1024) + return (size - 20 * 1024) / 4096 + 11; + + // From this point on, use a logarithmic scale. + int result = LogBase2(size) + 1; + + COMPILE_ASSERT(kDataSizesLength > 16, update_the_scale); + if (result >= kDataSizesLength) + result = kDataSizesLength - 1; + + return result; +} + +int Stats::GetRatio(Counters hit, Counters miss) const { + int64 ratio = GetCounter(hit) * 100; + if (!ratio) + return 0; + + ratio /= (GetCounter(hit) + GetCounter(miss)); + return static_cast<int>(ratio); +} + } // namespace disk_cache diff --git a/net/ftp/ftp_directory_listing_parser_vms.h b/net/ftp/ftp_directory_listing_parser_vms.h index 118365d..6f7fb73 100644 --- a/net/ftp/ftp_directory_listing_parser_vms.h +++ b/net/ftp/ftp_directory_listing_parser_vms.h @@ -26,10 +26,6 @@ class FtpDirectoryListingParserVms : public FtpDirectoryListingParser { virtual FtpDirectoryListingEntry PopEntry(); private: - // Consumes listing line which is expected to be a directory listing entry - // (and not a comment etc). Returns true on success. - bool ConsumeEntryLine(const string16& line); - enum State { STATE_INITIAL, @@ -46,7 +42,13 @@ class FtpDirectoryListingParserVms : public FtpDirectoryListingParser { // Indicates that we have successfully received all parts of the listing. STATE_END, - } state_; + }; + + // Consumes listing line which is expected to be a directory listing entry + // (and not a comment etc). Returns true on success. + bool ConsumeEntryLine(const string16& line); + + State state_; // VMS can use two physical lines if the filename is long. The first line will // contain the filename, and the second line everything else. Store the diff --git a/net/ftp/ftp_network_transaction.cc b/net/ftp/ftp_network_transaction.cc index 0285e08..d012818 100644 --- a/net/ftp/ftp_network_transaction.cc +++ b/net/ftp/ftp_network_transaction.cc @@ -204,6 +204,20 @@ FtpNetworkTransaction::FtpNetworkTransaction( FtpNetworkTransaction::~FtpNetworkTransaction() { } +int FtpNetworkTransaction::Stop(int error) { + if (command_sent_ == COMMAND_QUIT) + return error; + + next_state_ = STATE_CTRL_WRITE_QUIT; + last_error_ = error; + return OK; +} + +int FtpNetworkTransaction::RestartIgnoringLastError( + CompletionCallback* callback) { + return ERR_NOT_IMPLEMENTED; +} + int FtpNetworkTransaction::Start(const FtpRequestInfo* request_info, CompletionCallback* callback, const BoundNetLog& net_log) { @@ -226,15 +240,6 @@ int FtpNetworkTransaction::Start(const FtpRequestInfo* request_info, return rv; } -int FtpNetworkTransaction::Stop(int error) { - if (command_sent_ == COMMAND_QUIT) - return error; - - next_state_ = STATE_CTRL_WRITE_QUIT; - last_error_ = error; - return OK; -} - int FtpNetworkTransaction::RestartWithAuth(const string16& username, const string16& password, CompletionCallback* callback) { @@ -250,11 +255,6 @@ int FtpNetworkTransaction::RestartWithAuth(const string16& username, return rv; } -int FtpNetworkTransaction::RestartIgnoringLastError( - CompletionCallback* callback) { - return ERR_NOT_IMPLEMENTED; -} - int FtpNetworkTransaction::Read(IOBuffer* buf, int buf_len, CompletionCallback* callback) { @@ -302,34 +302,37 @@ uint64 FtpNetworkTransaction::GetUploadProgress() const { return 0; } -// Used to prepare and send FTP command. -int FtpNetworkTransaction::SendFtpCommand(const std::string& command, - Command cmd) { - // If we send a new command when we still have unprocessed responses - // for previous commands, the response receiving code will have no way to know - // which responses are for which command. - DCHECK(!ctrl_response_buffer_->ResponseAvailable()); - - DCHECK(!write_command_buf_); - DCHECK(!write_buf_); - - if (!IsValidFTPCommandString(command)) { - // Callers should validate the command themselves and return a more specific - // error code. - NOTREACHED(); - return Stop(ERR_UNEXPECTED); - } +void FtpNetworkTransaction::ResetStateForRestart() { + command_sent_ = COMMAND_NONE; + user_callback_ = NULL; + response_ = FtpResponseInfo(); + read_ctrl_buf_ = new IOBuffer(kCtrlBufLen); + ctrl_response_buffer_.reset(new FtpCtrlResponseBuffer()); + read_data_buf_ = NULL; + read_data_buf_len_ = 0; + if (write_buf_) + write_buf_->SetOffset(0); + last_error_ = OK; + data_connection_port_ = 0; + ctrl_socket_.reset(); + data_socket_.reset(); + next_state_ = STATE_NONE; +} - command_sent_ = cmd; +void FtpNetworkTransaction::DoCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_callback_); - write_command_buf_ = new IOBufferWithSize(command.length() + 2); - write_buf_ = new DrainableIOBuffer(write_command_buf_, - write_command_buf_->size()); - memcpy(write_command_buf_->data(), command.data(), command.length()); - memcpy(write_command_buf_->data() + command.length(), kCRLF, 2); + // Since Run may result in Read being called, clear callback_ up front. + CompletionCallback* c = user_callback_; + user_callback_ = NULL; + c->Run(rv); +} - next_state_ = STATE_CTRL_WRITE; - return OK; +void FtpNetworkTransaction::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + DoCallback(rv); } int FtpNetworkTransaction::ProcessCtrlResponse() { @@ -403,37 +406,34 @@ int FtpNetworkTransaction::ProcessCtrlResponse() { return rv; } -void FtpNetworkTransaction::ResetStateForRestart() { - command_sent_ = COMMAND_NONE; - user_callback_ = NULL; - response_ = FtpResponseInfo(); - read_ctrl_buf_ = new IOBuffer(kCtrlBufLen); - ctrl_response_buffer_.reset(new FtpCtrlResponseBuffer()); - read_data_buf_ = NULL; - read_data_buf_len_ = 0; - if (write_buf_) - write_buf_->SetOffset(0); - last_error_ = OK; - data_connection_port_ = 0; - ctrl_socket_.reset(); - data_socket_.reset(); - next_state_ = STATE_NONE; -} +// Used to prepare and send FTP command. +int FtpNetworkTransaction::SendFtpCommand(const std::string& command, + Command cmd) { + // If we send a new command when we still have unprocessed responses + // for previous commands, the response receiving code will have no way to know + // which responses are for which command. + DCHECK(!ctrl_response_buffer_->ResponseAvailable()); -void FtpNetworkTransaction::DoCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_callback_); + DCHECK(!write_command_buf_); + DCHECK(!write_buf_); - // Since Run may result in Read being called, clear callback_ up front. - CompletionCallback* c = user_callback_; - user_callback_ = NULL; - c->Run(rv); -} + if (!IsValidFTPCommandString(command)) { + // Callers should validate the command themselves and return a more specific + // error code. + NOTREACHED(); + return Stop(ERR_UNEXPECTED); + } -void FtpNetworkTransaction::OnIOComplete(int result) { - int rv = DoLoop(result); - if (rv != ERR_IO_PENDING) - DoCallback(rv); + command_sent_ = cmd; + + write_command_buf_ = new IOBufferWithSize(command.length() + 2); + write_buf_ = new DrainableIOBuffer(write_command_buf_, + write_command_buf_->size()); + memcpy(write_command_buf_->data(), command.data(), command.length()); + memcpy(write_command_buf_->data() + command.length(), kCRLF, 2); + + next_state_ = STATE_CTRL_WRITE; + return OK; } std::string FtpNetworkTransaction::GetRequestPathForFtpCommand( @@ -947,56 +947,6 @@ int FtpNetworkTransaction::ProcessResponsePASV( return OK; } -// SIZE command -int FtpNetworkTransaction::DoCtrlWriteSIZE() { - std::string command = "SIZE " + GetRequestPathForFtpCommand(false); - next_state_ = STATE_CTRL_READ; - return SendFtpCommand(command, COMMAND_SIZE); -} - -int FtpNetworkTransaction::ProcessResponseSIZE( - const FtpCtrlResponse& response) { - switch (GetErrorClass(response.status_code)) { - case ERROR_CLASS_INITIATED: - break; - case ERROR_CLASS_OK: - if (response.lines.size() != 1) - return Stop(ERR_INVALID_RESPONSE); - int64 size; - if (!base::StringToInt64(response.lines[0], &size)) - return Stop(ERR_INVALID_RESPONSE); - if (size < 0) - return Stop(ERR_INVALID_RESPONSE); - - // A successful response to SIZE does not mean the resource is a file. - // Some FTP servers (for example, the qnx one) send a SIZE even for - // directories. - response_.expected_content_size = size; - break; - case ERROR_CLASS_INFO_NEEDED: - break; - case ERROR_CLASS_TRANSIENT_ERROR: - break; - case ERROR_CLASS_PERMANENT_ERROR: - // It's possible that SIZE failed because the path is a directory. - if (resource_type_ == RESOURCE_TYPE_UNKNOWN && - response.status_code != 550) { - return Stop(GetNetErrorCodeForFtpResponseCode(response.status_code)); - } - break; - default: - NOTREACHED(); - return Stop(ERR_UNEXPECTED); - } - - if (resource_type_ == RESOURCE_TYPE_FILE) - next_state_ = STATE_CTRL_WRITE_RETR; - else - next_state_ = STATE_CTRL_WRITE_CWD; - - return OK; -} - // RETR command int FtpNetworkTransaction::DoCtrlWriteRETR() { std::string command = "RETR " + GetRequestPathForFtpCommand(false); @@ -1048,6 +998,56 @@ int FtpNetworkTransaction::ProcessResponseRETR( return OK; } +// SIZE command +int FtpNetworkTransaction::DoCtrlWriteSIZE() { + std::string command = "SIZE " + GetRequestPathForFtpCommand(false); + next_state_ = STATE_CTRL_READ; + return SendFtpCommand(command, COMMAND_SIZE); +} + +int FtpNetworkTransaction::ProcessResponseSIZE( + const FtpCtrlResponse& response) { + switch (GetErrorClass(response.status_code)) { + case ERROR_CLASS_INITIATED: + break; + case ERROR_CLASS_OK: + if (response.lines.size() != 1) + return Stop(ERR_INVALID_RESPONSE); + int64 size; + if (!base::StringToInt64(response.lines[0], &size)) + return Stop(ERR_INVALID_RESPONSE); + if (size < 0) + return Stop(ERR_INVALID_RESPONSE); + + // A successful response to SIZE does not mean the resource is a file. + // Some FTP servers (for example, the qnx one) send a SIZE even for + // directories. + response_.expected_content_size = size; + break; + case ERROR_CLASS_INFO_NEEDED: + break; + case ERROR_CLASS_TRANSIENT_ERROR: + break; + case ERROR_CLASS_PERMANENT_ERROR: + // It's possible that SIZE failed because the path is a directory. + if (resource_type_ == RESOURCE_TYPE_UNKNOWN && + response.status_code != 550) { + return Stop(GetNetErrorCodeForFtpResponseCode(response.status_code)); + } + break; + default: + NOTREACHED(); + return Stop(ERR_UNEXPECTED); + } + + if (resource_type_ == RESOURCE_TYPE_FILE) + next_state_ = STATE_CTRL_WRITE_RETR; + else + next_state_ = STATE_CTRL_WRITE_CWD; + + return OK; +} + // CWD command int FtpNetworkTransaction::DoCtrlWriteCWD() { std::string command = "CWD " + GetRequestPathForFtpCommand(true); diff --git a/net/ftp/ftp_network_transaction.h b/net/ftp/ftp_network_transaction.h index 678308a..c4516a4 100644 --- a/net/ftp/ftp_network_transaction.h +++ b/net/ftp/ftp_network_transaction.h @@ -31,15 +31,16 @@ class FtpNetworkTransaction : public FtpTransaction { ClientSocketFactory* socket_factory); virtual ~FtpNetworkTransaction(); + virtual int Stop(int error); + virtual int RestartIgnoringLastError(CompletionCallback* callback); + // FtpTransaction methods: virtual int Start(const FtpRequestInfo* request_info, CompletionCallback* callback, const BoundNetLog& net_log); - virtual int Stop(int error); virtual int RestartWithAuth(const string16& username, const string16& password, CompletionCallback* callback); - virtual int RestartIgnoringLastError(CompletionCallback* callback); virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback); virtual const FtpResponseInfo* GetResponseInfo() const; virtual LoadState GetLoadState() const; @@ -87,6 +88,36 @@ class FtpNetworkTransaction : public FtpTransaction { RESOURCE_TYPE_DIRECTORY, }; + enum State { + // Control connection states: + STATE_CTRL_RESOLVE_HOST, + STATE_CTRL_RESOLVE_HOST_COMPLETE, + STATE_CTRL_CONNECT, + STATE_CTRL_CONNECT_COMPLETE, + STATE_CTRL_READ, + STATE_CTRL_READ_COMPLETE, + STATE_CTRL_WRITE, + STATE_CTRL_WRITE_COMPLETE, + STATE_CTRL_WRITE_USER, + STATE_CTRL_WRITE_PASS, + STATE_CTRL_WRITE_SYST, + STATE_CTRL_WRITE_TYPE, + STATE_CTRL_WRITE_EPSV, + STATE_CTRL_WRITE_PASV, + STATE_CTRL_WRITE_PWD, + STATE_CTRL_WRITE_RETR, + STATE_CTRL_WRITE_SIZE, + STATE_CTRL_WRITE_CWD, + STATE_CTRL_WRITE_LIST, + STATE_CTRL_WRITE_QUIT, + // Data connection states: + STATE_DATA_CONNECT, + STATE_DATA_CONNECT_COMPLETE, + STATE_DATA_READ, + STATE_DATA_READ_COMPLETE, + STATE_NONE + }; + // Resets the members of the transaction so it can be restarted. void ResetStateForRestart(); @@ -211,35 +242,6 @@ class FtpNetworkTransaction : public FtpTransaction { scoped_ptr<ClientSocket> ctrl_socket_; scoped_ptr<ClientSocket> data_socket_; - enum State { - // Control connection states: - STATE_CTRL_RESOLVE_HOST, - STATE_CTRL_RESOLVE_HOST_COMPLETE, - STATE_CTRL_CONNECT, - STATE_CTRL_CONNECT_COMPLETE, - STATE_CTRL_READ, - STATE_CTRL_READ_COMPLETE, - STATE_CTRL_WRITE, - STATE_CTRL_WRITE_COMPLETE, - STATE_CTRL_WRITE_USER, - STATE_CTRL_WRITE_PASS, - STATE_CTRL_WRITE_SYST, - STATE_CTRL_WRITE_TYPE, - STATE_CTRL_WRITE_EPSV, - STATE_CTRL_WRITE_PASV, - STATE_CTRL_WRITE_PWD, - STATE_CTRL_WRITE_RETR, - STATE_CTRL_WRITE_SIZE, - STATE_CTRL_WRITE_CWD, - STATE_CTRL_WRITE_LIST, - STATE_CTRL_WRITE_QUIT, - // Data connection states: - STATE_DATA_CONNECT, - STATE_DATA_CONNECT_COMPLETE, - STATE_DATA_READ, - STATE_DATA_READ_COMPLETE, - STATE_NONE - }; State next_state_; }; diff --git a/net/http/disk_cache_based_ssl_host_info.cc b/net/http/disk_cache_based_ssl_host_info.cc index 1b1dfaf..1a875cb 100644 --- a/net/http/disk_cache_based_ssl_host_info.cc +++ b/net/http/disk_cache_based_ssl_host_info.cc @@ -13,6 +13,24 @@ namespace net { +DiskCacheBasedSSLHostInfo::CallbackImpl::CallbackImpl( + const base::WeakPtr<DiskCacheBasedSSLHostInfo>& obj, + void (DiskCacheBasedSSLHostInfo::*meth) (int)) + : obj_(obj), + meth_(meth) { +} + +DiskCacheBasedSSLHostInfo::CallbackImpl::~CallbackImpl() {} + +void DiskCacheBasedSSLHostInfo::CallbackImpl::RunWithParams( + const Tuple1<int>& params) { + if (!obj_) { + delete this; + } else { + DispatchToMethod(obj_.get(), meth_, params); + } +} + DiskCacheBasedSSLHostInfo::DiskCacheBasedSSLHostInfo( const std::string& hostname, const SSLConfig& ssl_config, @@ -37,6 +55,35 @@ void DiskCacheBasedSSLHostInfo::Start() { DoLoop(OK); } +int DiskCacheBasedSSLHostInfo::WaitForDataReady(CompletionCallback* callback) { + DCHECK(CalledOnValidThread()); + DCHECK(state_ != GET_BACKEND); + + if (ready_) + return OK; + if (callback) { + DCHECK(!user_callback_); + user_callback_ = callback; + } + return ERR_IO_PENDING; +} + +void DiskCacheBasedSSLHostInfo::Persist() { + DCHECK(CalledOnValidThread()); + DCHECK(state_ != GET_BACKEND); + + DCHECK(new_data_.empty()); + CHECK(ready_); + DCHECK(user_callback_ == NULL); + new_data_ = Serialize(); + + if (!backend_) + return; + + state_ = CREATE; + DoLoop(OK); +} + DiskCacheBasedSSLHostInfo::~DiskCacheBasedSSLHostInfo() { DCHECK(!user_callback_); if (entry_) @@ -95,24 +142,6 @@ void DiskCacheBasedSSLHostInfo::DoLoop(int rv) { } while (rv != ERR_IO_PENDING && state_ != NONE); } -bool DiskCacheBasedSSLHostInfo::IsCallbackPending() const { - switch (state_) { - case GET_BACKEND_COMPLETE: - case OPEN_COMPLETE: - case READ_COMPLETE: - case CREATE_COMPLETE: - case WRITE_COMPLETE: - return true; - default: - return false; - } -} - -int DiskCacheBasedSSLHostInfo::DoGetBackend() { - state_ = GET_BACKEND_COMPLETE; - return http_cache_->GetBackend(callback_->backend_pointer(), callback_); -} - int DiskCacheBasedSSLHostInfo::DoGetBackendComplete(int rv) { if (rv == OK) { backend_ = callback_->backend(); @@ -123,11 +152,6 @@ int DiskCacheBasedSSLHostInfo::DoGetBackendComplete(int rv) { return OK; } -int DiskCacheBasedSSLHostInfo::DoOpen() { - state_ = OPEN_COMPLETE; - return backend_->OpenEntry(key(), callback_->entry_pointer(), callback_); -} - int DiskCacheBasedSSLHostInfo::DoOpenComplete(int rv) { if (rv == OK) { entry_ = callback_->entry(); @@ -139,6 +163,39 @@ int DiskCacheBasedSSLHostInfo::DoOpenComplete(int rv) { return OK; } +int DiskCacheBasedSSLHostInfo::DoReadComplete(int rv) { + if (rv > 0) + data_ = std::string(read_buffer_->data(), rv); + + state_ = WAIT_FOR_DATA_READY_DONE; + return OK; +} + +int DiskCacheBasedSSLHostInfo::DoWriteComplete(int rv) { + state_ = SET_DONE; + return OK; +} + +int DiskCacheBasedSSLHostInfo::DoCreateComplete(int rv) { + if (rv != OK) { + state_ = SET_DONE; + } else { + entry_ = callback_->entry(); + state_ = WRITE; + } + return OK; +} + +int DiskCacheBasedSSLHostInfo::DoGetBackend() { + state_ = GET_BACKEND_COMPLETE; + return http_cache_->GetBackend(callback_->backend_pointer(), callback_); +} + +int DiskCacheBasedSSLHostInfo::DoOpen() { + state_ = OPEN_COMPLETE; + return backend_->OpenEntry(key(), callback_->entry_pointer(), callback_); +} + int DiskCacheBasedSSLHostInfo::DoRead() { const int32 size = entry_->GetDataSize(0 /* index */); if (!size) { @@ -152,12 +209,19 @@ int DiskCacheBasedSSLHostInfo::DoRead() { size, callback_); } -int DiskCacheBasedSSLHostInfo::DoReadComplete(int rv) { - if (rv > 0) - data_ = std::string(read_buffer_->data(), rv); +int DiskCacheBasedSSLHostInfo::DoWrite() { + write_buffer_ = new IOBuffer(new_data_.size()); + memcpy(write_buffer_->data(), new_data_.data(), new_data_.size()); + state_ = WRITE_COMPLETE; - state_ = WAIT_FOR_DATA_READY_DONE; - return OK; + return entry_->WriteData(0 /* index */, 0 /* offset */, write_buffer_, + new_data_.size(), callback_, true /* truncate */); +} + +int DiskCacheBasedSSLHostInfo::DoCreate() { + DCHECK(entry_ == NULL); + state_ = CREATE_COMPLETE; + return backend_->CreateEntry(key(), callback_->entry_pointer(), callback_); } int DiskCacheBasedSSLHostInfo::WaitForDataReadyDone() { @@ -181,65 +245,6 @@ int DiskCacheBasedSSLHostInfo::WaitForDataReadyDone() { return OK; } -int DiskCacheBasedSSLHostInfo::WaitForDataReady(CompletionCallback* callback) { - DCHECK(CalledOnValidThread()); - DCHECK(state_ != GET_BACKEND); - - if (ready_) - return OK; - if (callback) { - DCHECK(!user_callback_); - user_callback_ = callback; - } - return ERR_IO_PENDING; -} - -void DiskCacheBasedSSLHostInfo::Persist() { - DCHECK(CalledOnValidThread()); - DCHECK(state_ != GET_BACKEND); - - DCHECK(new_data_.empty()); - CHECK(ready_); - DCHECK(user_callback_ == NULL); - new_data_ = Serialize(); - - if (!backend_) - return; - - state_ = CREATE; - DoLoop(OK); -} - -int DiskCacheBasedSSLHostInfo::DoCreate() { - DCHECK(entry_ == NULL); - state_ = CREATE_COMPLETE; - return backend_->CreateEntry(key(), callback_->entry_pointer(), callback_); -} - -int DiskCacheBasedSSLHostInfo::DoCreateComplete(int rv) { - if (rv != OK) { - state_ = SET_DONE; - } else { - entry_ = callback_->entry(); - state_ = WRITE; - } - return OK; -} - -int DiskCacheBasedSSLHostInfo::DoWrite() { - write_buffer_ = new IOBuffer(new_data_.size()); - memcpy(write_buffer_->data(), new_data_.data(), new_data_.size()); - state_ = WRITE_COMPLETE; - - return entry_->WriteData(0 /* index */, 0 /* offset */, write_buffer_, - new_data_.size(), callback_, true /* truncate */); -} - -int DiskCacheBasedSSLHostInfo::DoWriteComplete(int rv) { - state_ = SET_DONE; - return OK; -} - int DiskCacheBasedSSLHostInfo::SetDone() { if (entry_) entry_->Close(); @@ -248,4 +253,17 @@ int DiskCacheBasedSSLHostInfo::SetDone() { return OK; } +bool DiskCacheBasedSSLHostInfo::IsCallbackPending() const { + switch (state_) { + case GET_BACKEND_COMPLETE: + case OPEN_COMPLETE: + case READ_COMPLETE: + case CREATE_COMPLETE: + case WRITE_COMPLETE: + return true; + default: + return false; + } +} + } // namespace net diff --git a/net/http/disk_cache_based_ssl_host_info.h b/net/http/disk_cache_based_ssl_host_info.h index 2beb7e4..9d04ba0 100644 --- a/net/http/disk_cache_based_ssl_host_info.h +++ b/net/http/disk_cache_based_ssl_host_info.h @@ -52,29 +52,20 @@ class DiskCacheBasedSSLHostInfo : public SSLHostInfo, NONE, }; - ~DiskCacheBasedSSLHostInfo(); - class CallbackImpl : public CallbackRunner<Tuple1<int> > { public: CallbackImpl(const base::WeakPtr<DiskCacheBasedSSLHostInfo>& obj, - void (DiskCacheBasedSSLHostInfo::*meth) (int)) - : obj_(obj), - meth_(meth) { - } - - virtual void RunWithParams(const Tuple1<int>& params) { - if (!obj_) { - delete this; - } else { - DispatchToMethod(obj_.get(), meth_, params); - } - } + void (DiskCacheBasedSSLHostInfo::*meth) (int)); + virtual ~CallbackImpl(); disk_cache::Backend** backend_pointer() { return &backend_; } disk_cache::Entry** entry_pointer() { return &entry_; } disk_cache::Backend* backend() const { return backend_; } disk_cache::Entry* entry() const { return entry_; } + // CallbackRunner<Tuple1<int> >: + virtual void RunWithParams(const Tuple1<int>& params); + private: base::WeakPtr<DiskCacheBasedSSLHostInfo> obj_; void (DiskCacheBasedSSLHostInfo::*meth_) (int); @@ -83,6 +74,8 @@ class DiskCacheBasedSSLHostInfo : public SSLHostInfo, disk_cache::Entry* entry_; }; + virtual ~DiskCacheBasedSSLHostInfo(); + std::string key() const; void DoLoop(int rv); @@ -96,11 +89,12 @@ class DiskCacheBasedSSLHostInfo : public SSLHostInfo, int DoGetBackend(); int DoOpen(); int DoRead(); - int DoCreate(); int DoWrite(); + int DoCreate(); // WaitForDataReadyDone is the terminal state of the read operation. int WaitForDataReadyDone(); + // SetDone is the terminal state of the write operation. int SetDone(); diff --git a/net/http/http_auth_filter.cc b/net/http/http_auth_filter.cc index a61e7f7..2109e4d 100644 --- a/net/http/http_auth_filter.cc +++ b/net/http/http_auth_filter.cc @@ -24,21 +24,6 @@ HttpAuthFilterWhitelist::HttpAuthFilterWhitelist( HttpAuthFilterWhitelist::~HttpAuthFilterWhitelist() { } -void HttpAuthFilterWhitelist::SetWhitelist( - const std::string& server_whitelist) { - rules_.ParseFromString(server_whitelist); -} - -bool HttpAuthFilterWhitelist::IsValid(const GURL& url, - HttpAuth::Target target) const { - if ((target != HttpAuth::AUTH_SERVER) && (target != HttpAuth::AUTH_PROXY)) - return false; - // All proxies pass - if (target == HttpAuth::AUTH_PROXY) - return true; - return rules_.Matches(url); -} - // Add a new domain |filter| to the whitelist, if it's not already there bool HttpAuthFilterWhitelist::AddFilter(const std::string& filter, HttpAuth::Target target) { @@ -55,4 +40,19 @@ void HttpAuthFilterWhitelist::AddRuleToBypassLocal() { rules_.AddRuleToBypassLocal(); } +bool HttpAuthFilterWhitelist::IsValid(const GURL& url, + HttpAuth::Target target) const { + if ((target != HttpAuth::AUTH_SERVER) && (target != HttpAuth::AUTH_PROXY)) + return false; + // All proxies pass + if (target == HttpAuth::AUTH_PROXY) + return true; + return rules_.Matches(url); +} + +void HttpAuthFilterWhitelist::SetWhitelist( + const std::string& server_whitelist) { + rules_.ParseFromString(server_whitelist); +} + } // namespace net diff --git a/net/http/http_auth_filter.h b/net/http/http_auth_filter.h index 334bc91..81d414c 100644 --- a/net/http/http_auth_filter.h +++ b/net/http/http_auth_filter.h @@ -37,9 +37,6 @@ class HttpAuthFilterWhitelist : public HttpAuthFilter { explicit HttpAuthFilterWhitelist(const std::string& server_whitelist); virtual ~HttpAuthFilterWhitelist(); - // HttpAuthFilter methods: - virtual bool IsValid(const GURL& url, HttpAuth::Target target) const; - // Adds an individual URL |filter| to the list, of the specified |target|. bool AddFilter(const std::string& filter, HttpAuth::Target target); @@ -48,6 +45,9 @@ class HttpAuthFilterWhitelist : public HttpAuthFilter { const ProxyBypassRules& rules() const { return rules_; } + // HttpAuthFilter methods: + virtual bool IsValid(const GURL& url, HttpAuth::Target target) const; + private: // Installs the whitelist. // |server_whitelist| is parsed by ProxyBypassRules. diff --git a/net/http/http_auth_handler_digest.cc b/net/http/http_auth_handler_digest.cc index 7c5526c..e8cb819 100644 --- a/net/http/http_auth_handler_digest.cc +++ b/net/http/http_auth_handler_digest.cc @@ -74,46 +74,60 @@ std::string HttpAuthHandlerDigest::FixedNonceGenerator::GenerateNonce() const { return nonce_; } -// static -std::string HttpAuthHandlerDigest::QopToString(QualityOfProtection qop) { - switch (qop) { - case QOP_UNSPECIFIED: - return ""; - case QOP_AUTH: - return "auth"; - default: - NOTREACHED(); - return ""; - } +HttpAuthHandlerDigest::Factory::Factory() + : nonce_generator_(new DynamicNonceGenerator()) { } -// static -std::string HttpAuthHandlerDigest::AlgorithmToString( - DigestAlgorithm algorithm) { - switch (algorithm) { - case ALGORITHM_UNSPECIFIED: - return ""; - case ALGORITHM_MD5: - return "MD5"; - case ALGORITHM_MD5_SESS: - return "MD5-sess"; - default: - NOTREACHED(); - return ""; - } +HttpAuthHandlerDigest::Factory::~Factory() { } -HttpAuthHandlerDigest::HttpAuthHandlerDigest( - int nonce_count, const NonceGenerator* nonce_generator) - : stale_(false), - algorithm_(ALGORITHM_UNSPECIFIED), - qop_(QOP_UNSPECIFIED), - nonce_count_(nonce_count), - nonce_generator_(nonce_generator) { - DCHECK(nonce_generator_); +void HttpAuthHandlerDigest::Factory::set_nonce_generator( + const NonceGenerator* nonce_generator) { + nonce_generator_.reset(nonce_generator); } -HttpAuthHandlerDigest::~HttpAuthHandlerDigest() { +int HttpAuthHandlerDigest::Factory::CreateAuthHandler( + HttpAuth::ChallengeTokenizer* challenge, + HttpAuth::Target target, + const GURL& origin, + CreateReason reason, + int digest_nonce_count, + const BoundNetLog& net_log, + scoped_ptr<HttpAuthHandler>* handler) { + // TODO(cbentzel): Move towards model of parsing in the factory + // method and only constructing when valid. + scoped_ptr<HttpAuthHandler> tmp_handler( + new HttpAuthHandlerDigest(digest_nonce_count, nonce_generator_.get())); + if (!tmp_handler->InitFromChallenge(challenge, target, origin, net_log)) + return ERR_INVALID_RESPONSE; + handler->swap(tmp_handler); + return OK; +} + +HttpAuth::AuthorizationResult HttpAuthHandlerDigest::HandleAnotherChallenge( + HttpAuth::ChallengeTokenizer* challenge) { + // Even though Digest is not connection based, a "second round" is parsed + // to differentiate between stale and rejected responses. + // Note that the state of the current handler is not mutated - this way if + // there is a rejection the realm hasn't changed. + if (!LowerCaseEqualsASCII(challenge->scheme(), "digest")) + return HttpAuth::AUTHORIZATION_RESULT_INVALID; + + HttpUtil::NameValuePairsIterator parameters = challenge->param_pairs(); + + // Try to find the "stale" value. + while (parameters.GetNext()) { + if (!LowerCaseEqualsASCII(parameters.name(), "stale")) + continue; + if (LowerCaseEqualsASCII(parameters.value(), "true")) + return HttpAuth::AUTHORIZATION_RESULT_STALE; + } + + return HttpAuth::AUTHORIZATION_RESULT_REJECT; +} + +bool HttpAuthHandlerDigest::Init(HttpAuth::ChallengeTokenizer* challenge) { + return ParseChallenge(challenge); } int HttpAuthHandlerDigest::GenerateAuthTokenImpl( @@ -138,112 +152,17 @@ int HttpAuthHandlerDigest::GenerateAuthTokenImpl( return OK; } -void HttpAuthHandlerDigest::GetRequestMethodAndPath( - const HttpRequestInfo* request, - std::string* method, - std::string* path) const { - DCHECK(request); - - const GURL& url = request->url; - - if (target_ == HttpAuth::AUTH_PROXY && url.SchemeIs("https")) { - *method = "CONNECT"; - *path = GetHostAndPort(url); - } else { - *method = request->method; - *path = HttpUtil::PathForRequest(url); - } -} - -std::string HttpAuthHandlerDigest::AssembleResponseDigest( - const std::string& method, - const std::string& path, - const string16& username, - const string16& password, - const std::string& cnonce, - const std::string& nc) const { - // ha1 = MD5(A1) - // TODO(eroman): is this the right encoding? - std::string ha1 = MD5String(UTF16ToUTF8(username) + ":" + realm_ + ":" + - UTF16ToUTF8(password)); - if (algorithm_ == HttpAuthHandlerDigest::ALGORITHM_MD5_SESS) - ha1 = MD5String(ha1 + ":" + nonce_ + ":" + cnonce); - - // ha2 = MD5(A2) - // TODO(eroman): need to add MD5(req-entity-body) for qop=auth-int. - std::string ha2 = MD5String(method + ":" + path); - - std::string nc_part; - if (qop_ != HttpAuthHandlerDigest::QOP_UNSPECIFIED) { - nc_part = nc + ":" + cnonce + ":" + QopToString(qop_) + ":"; - } - - return MD5String(ha1 + ":" + nonce_ + ":" + nc_part + ha2); -} - -std::string HttpAuthHandlerDigest::AssembleCredentials( - const std::string& method, - const std::string& path, - const string16& username, - const string16& password, - const std::string& cnonce, - int nonce_count) const { - // the nonce-count is an 8 digit hex string. - std::string nc = base::StringPrintf("%08x", nonce_count); - - // TODO(eroman): is this the right encoding? - std::string authorization = (std::string("Digest username=") + - HttpUtil::Quote(UTF16ToUTF8(username))); - authorization += ", realm=" + HttpUtil::Quote(realm_); - authorization += ", nonce=" + HttpUtil::Quote(nonce_); - authorization += ", uri=" + HttpUtil::Quote(path); - - if (algorithm_ != ALGORITHM_UNSPECIFIED) { - authorization += ", algorithm=" + AlgorithmToString(algorithm_); - } - std::string response = AssembleResponseDigest(method, path, username, - password, cnonce, nc); - // No need to call HttpUtil::Quote() as the response digest cannot contain - // any characters needing to be escaped. - authorization += ", response=\"" + response + "\""; - - if (!opaque_.empty()) { - authorization += ", opaque=" + HttpUtil::Quote(opaque_); - } - if (qop_ != QOP_UNSPECIFIED) { - // TODO(eroman): Supposedly IIS server requires quotes surrounding qop. - authorization += ", qop=" + QopToString(qop_); - authorization += ", nc=" + nc; - authorization += ", cnonce=" + HttpUtil::Quote(cnonce); - } - - return authorization; -} - -bool HttpAuthHandlerDigest::Init(HttpAuth::ChallengeTokenizer* challenge) { - return ParseChallenge(challenge); +HttpAuthHandlerDigest::HttpAuthHandlerDigest( + int nonce_count, const NonceGenerator* nonce_generator) + : stale_(false), + algorithm_(ALGORITHM_UNSPECIFIED), + qop_(QOP_UNSPECIFIED), + nonce_count_(nonce_count), + nonce_generator_(nonce_generator) { + DCHECK(nonce_generator_); } -HttpAuth::AuthorizationResult HttpAuthHandlerDigest::HandleAnotherChallenge( - HttpAuth::ChallengeTokenizer* challenge) { - // Even though Digest is not connection based, a "second round" is parsed - // to differentiate between stale and rejected responses. - // Note that the state of the current handler is not mutated - this way if - // there is a rejection the realm hasn't changed. - if (!LowerCaseEqualsASCII(challenge->scheme(), "digest")) - return HttpAuth::AUTHORIZATION_RESULT_INVALID; - - HttpUtil::NameValuePairsIterator parameters = challenge->param_pairs(); - - // Try to find the "stale" value. - while (parameters.GetNext()) { - if (!LowerCaseEqualsASCII(parameters.name(), "stale")) - continue; - if (LowerCaseEqualsASCII(parameters.value(), "true")) - return HttpAuth::AUTHORIZATION_RESULT_STALE; - } - - return HttpAuth::AUTHORIZATION_RESULT_REJECT; +HttpAuthHandlerDigest::~HttpAuthHandlerDigest() { } // The digest challenge header looks like: @@ -342,34 +261,115 @@ bool HttpAuthHandlerDigest::ParseChallengeProperty(const std::string& name, return true; } -HttpAuthHandlerDigest::Factory::Factory() - : nonce_generator_(new DynamicNonceGenerator()) { +// static +std::string HttpAuthHandlerDigest::QopToString(QualityOfProtection qop) { + switch (qop) { + case QOP_UNSPECIFIED: + return ""; + case QOP_AUTH: + return "auth"; + default: + NOTREACHED(); + return ""; + } } -HttpAuthHandlerDigest::Factory::~Factory() { +// static +std::string HttpAuthHandlerDigest::AlgorithmToString( + DigestAlgorithm algorithm) { + switch (algorithm) { + case ALGORITHM_UNSPECIFIED: + return ""; + case ALGORITHM_MD5: + return "MD5"; + case ALGORITHM_MD5_SESS: + return "MD5-sess"; + default: + NOTREACHED(); + return ""; + } } -void HttpAuthHandlerDigest::Factory::set_nonce_generator( - const NonceGenerator* nonce_generator) { - nonce_generator_.reset(nonce_generator); +void HttpAuthHandlerDigest::GetRequestMethodAndPath( + const HttpRequestInfo* request, + std::string* method, + std::string* path) const { + DCHECK(request); + + const GURL& url = request->url; + + if (target_ == HttpAuth::AUTH_PROXY && url.SchemeIs("https")) { + *method = "CONNECT"; + *path = GetHostAndPort(url); + } else { + *method = request->method; + *path = HttpUtil::PathForRequest(url); + } } -int HttpAuthHandlerDigest::Factory::CreateAuthHandler( - HttpAuth::ChallengeTokenizer* challenge, - HttpAuth::Target target, - const GURL& origin, - CreateReason reason, - int digest_nonce_count, - const BoundNetLog& net_log, - scoped_ptr<HttpAuthHandler>* handler) { - // TODO(cbentzel): Move towards model of parsing in the factory - // method and only constructing when valid. - scoped_ptr<HttpAuthHandler> tmp_handler( - new HttpAuthHandlerDigest(digest_nonce_count, nonce_generator_.get())); - if (!tmp_handler->InitFromChallenge(challenge, target, origin, net_log)) - return ERR_INVALID_RESPONSE; - handler->swap(tmp_handler); - return OK; +std::string HttpAuthHandlerDigest::AssembleResponseDigest( + const std::string& method, + const std::string& path, + const string16& username, + const string16& password, + const std::string& cnonce, + const std::string& nc) const { + // ha1 = MD5(A1) + // TODO(eroman): is this the right encoding? + std::string ha1 = MD5String(UTF16ToUTF8(username) + ":" + realm_ + ":" + + UTF16ToUTF8(password)); + if (algorithm_ == HttpAuthHandlerDigest::ALGORITHM_MD5_SESS) + ha1 = MD5String(ha1 + ":" + nonce_ + ":" + cnonce); + + // ha2 = MD5(A2) + // TODO(eroman): need to add MD5(req-entity-body) for qop=auth-int. + std::string ha2 = MD5String(method + ":" + path); + + std::string nc_part; + if (qop_ != HttpAuthHandlerDigest::QOP_UNSPECIFIED) { + nc_part = nc + ":" + cnonce + ":" + QopToString(qop_) + ":"; + } + + return MD5String(ha1 + ":" + nonce_ + ":" + nc_part + ha2); +} + +std::string HttpAuthHandlerDigest::AssembleCredentials( + const std::string& method, + const std::string& path, + const string16& username, + const string16& password, + const std::string& cnonce, + int nonce_count) const { + // the nonce-count is an 8 digit hex string. + std::string nc = base::StringPrintf("%08x", nonce_count); + + // TODO(eroman): is this the right encoding? + std::string authorization = (std::string("Digest username=") + + HttpUtil::Quote(UTF16ToUTF8(username))); + authorization += ", realm=" + HttpUtil::Quote(realm_); + authorization += ", nonce=" + HttpUtil::Quote(nonce_); + authorization += ", uri=" + HttpUtil::Quote(path); + + if (algorithm_ != ALGORITHM_UNSPECIFIED) { + authorization += ", algorithm=" + AlgorithmToString(algorithm_); + } + std::string response = AssembleResponseDigest(method, path, username, + password, cnonce, nc); + // No need to call HttpUtil::Quote() as the response digest cannot contain + // any characters needing to be escaped. + authorization += ", response=\"" + response + "\""; + + if (!opaque_.empty()) { + authorization += ", opaque=" + HttpUtil::Quote(opaque_); + } + if (qop_ != QOP_UNSPECIFIED) { + // TODO(eroman): Supposedly IIS server requires quotes surrounding qop. + authorization += ", qop=" + QopToString(qop_); + authorization += ", nc=" + nc; + authorization += ", cnonce=" + HttpUtil::Quote(cnonce); + } + + return authorization; } } // namespace net diff --git a/net/http/http_auth_handler_digest.h b/net/http/http_auth_handler_digest.h index c319f5d..fca77e4 100644 --- a/net/http/http_auth_handler_digest.h +++ b/net/http/http_auth_handler_digest.h @@ -62,6 +62,9 @@ class HttpAuthHandlerDigest : public HttpAuthHandler { Factory(); virtual ~Factory(); + // This factory owns the passed in |nonce_generator|. + void set_nonce_generator(const NonceGenerator* nonce_generator); + virtual int CreateAuthHandler(HttpAuth::ChallengeTokenizer* challenge, HttpAuth::Target target, const GURL& origin, @@ -70,9 +73,6 @@ class HttpAuthHandlerDigest : public HttpAuthHandler { const BoundNetLog& net_log, scoped_ptr<HttpAuthHandler>* handler); - // This factory owns the passed in |nonce_generator|. - void set_nonce_generator(const NonceGenerator* nonce_generator); - private: scoped_ptr<const NonceGenerator> nonce_generator_; }; diff --git a/net/http/http_auth_handler_mock.cc b/net/http/http_auth_handler_mock.cc index b4e2268..aad1bd1 100644 --- a/net/http/http_auth_handler_mock.cc +++ b/net/http/http_auth_handler_mock.cc @@ -71,13 +71,6 @@ void HttpAuthHandlerMock::SetGenerateExpectation(bool async, int rv) { generate_rv_ = rv; } -bool HttpAuthHandlerMock::Init(HttpAuth::ChallengeTokenizer* challenge) { - auth_scheme_ = HttpAuth::AUTH_SCHEME_MOCK; - score_ = 1; - properties_ = connection_based_ ? IS_CONNECTION_BASED : 0; - return true; -} - HttpAuth::AuthorizationResult HttpAuthHandlerMock::HandleAnotherChallenge( HttpAuth::ChallengeTokenizer* challenge) { if (!is_connection_based()) @@ -87,6 +80,17 @@ HttpAuth::AuthorizationResult HttpAuthHandlerMock::HandleAnotherChallenge( return HttpAuth::AUTHORIZATION_RESULT_ACCEPT; } +bool HttpAuthHandlerMock::NeedsIdentity() { + return first_round_; +} + +bool HttpAuthHandlerMock::Init(HttpAuth::ChallengeTokenizer* challenge) { + auth_scheme_ = HttpAuth::AUTH_SCHEME_MOCK; + score_ = 1; + properties_ = connection_based_ ? IS_CONNECTION_BASED : 0; + return true; +} + int HttpAuthHandlerMock::GenerateAuthTokenImpl(const string16* username, const string16* password, const HttpRequestInfo* request, diff --git a/net/http/http_auth_handler_mock.h b/net/http/http_auth_handler_mock.h index bef8b2b..473ca2e 100644 --- a/net/http/http_auth_handler_mock.h +++ b/net/http/http_auth_handler_mock.h @@ -29,32 +29,6 @@ class HttpAuthHandlerMock : public HttpAuthHandler { RESOLVE_TESTED, }; - HttpAuthHandlerMock(); - - virtual ~HttpAuthHandlerMock(); - - void SetResolveExpectation(Resolve resolve); - - virtual bool NeedsCanonicalName(); - - virtual int ResolveCanonicalName(HostResolver* host_resolver, - CompletionCallback* callback); - - virtual bool NeedsIdentity() { return first_round_; } - - void SetGenerateExpectation(bool async, int rv); - - void set_connection_based(bool connection_based) { - connection_based_ = connection_based; - } - - const GURL& request_url() const { - return request_url_; - } - - HttpAuth::AuthorizationResult HandleAnotherChallenge( - HttpAuth::ChallengeTokenizer* challenge); - // The Factory class simply returns the same handler each time // CreateAuthHandler is called. class Factory : public HttpAuthHandlerFactory { @@ -68,6 +42,7 @@ class HttpAuthHandlerMock : public HttpAuthHandler { do_init_from_challenge_ = do_init_from_challenge; } + // HttpAuthHandlerFactory: virtual int CreateAuthHandler(HttpAuth::ChallengeTokenizer* challenge, HttpAuth::Target target, const GURL& origin, @@ -81,6 +56,33 @@ class HttpAuthHandlerMock : public HttpAuthHandler { bool do_init_from_challenge_; }; + HttpAuthHandlerMock(); + + virtual ~HttpAuthHandlerMock(); + + void SetResolveExpectation(Resolve resolve); + + virtual bool NeedsCanonicalName(); + + virtual int ResolveCanonicalName(HostResolver* host_resolver, + CompletionCallback* callback); + + + void SetGenerateExpectation(bool async, int rv); + + void set_connection_based(bool connection_based) { + connection_based_ = connection_based; + } + + const GURL& request_url() const { + return request_url_; + } + + // HttpAuthHandler: + virtual HttpAuth::AuthorizationResult HandleAnotherChallenge( + HttpAuth::ChallengeTokenizer* challenge); + virtual bool NeedsIdentity(); + protected: virtual bool Init(HttpAuth::ChallengeTokenizer* challenge); diff --git a/net/http/http_auth_handler_negotiate.cc b/net/http/http_auth_handler_negotiate.cc index cedd282..a96902d 100644 --- a/net/http/http_auth_handler_negotiate.cc +++ b/net/http/http_auth_handler_negotiate.cc @@ -16,6 +16,68 @@ namespace net { +HttpAuthHandlerNegotiate::Factory::Factory() + : disable_cname_lookup_(false), + use_port_(false), +#if defined(OS_WIN) + max_token_length_(0), + first_creation_(true), + is_unsupported_(false), +#endif + auth_library_(NULL) { +} + +HttpAuthHandlerNegotiate::Factory::~Factory() { +} + +void HttpAuthHandlerNegotiate::Factory::set_host_resolver( + HostResolver* resolver) { + resolver_ = resolver; +} + +int HttpAuthHandlerNegotiate::Factory::CreateAuthHandler( + HttpAuth::ChallengeTokenizer* challenge, + HttpAuth::Target target, + const GURL& origin, + CreateReason reason, + int digest_nonce_count, + const BoundNetLog& net_log, + scoped_ptr<HttpAuthHandler>* handler) { +#if defined(OS_WIN) + if (is_unsupported_ || reason == CREATE_PREEMPTIVE) + return ERR_UNSUPPORTED_AUTH_SCHEME; + if (max_token_length_ == 0) { + int rv = DetermineMaxTokenLength(auth_library_.get(), NEGOSSP_NAME, + &max_token_length_); + if (rv == ERR_UNSUPPORTED_AUTH_SCHEME) + is_unsupported_ = true; + if (rv != OK) + return rv; + } + // TODO(cbentzel): Move towards model of parsing in the factory + // method and only constructing when valid. + scoped_ptr<HttpAuthHandler> tmp_handler( + new HttpAuthHandlerNegotiate(auth_library_.get(), max_token_length_, + url_security_manager(), resolver_, + disable_cname_lookup_, use_port_)); + if (!tmp_handler->InitFromChallenge(challenge, target, origin, net_log)) + return ERR_INVALID_RESPONSE; + handler->swap(tmp_handler); + return OK; +#elif defined(OS_POSIX) + // TODO(ahendrickson): Move towards model of parsing in the factory + // method and only constructing when valid. + scoped_ptr<HttpAuthHandler> tmp_handler( + new HttpAuthHandlerNegotiate(auth_library_.get(), url_security_manager(), + resolver_, disable_cname_lookup_, + use_port_)); + if (!tmp_handler->InitFromChallenge(challenge, target, origin, net_log)) + return ERR_INVALID_RESPONSE; + handler->swap(tmp_handler); + return OK; +#endif +} + HttpAuthHandlerNegotiate::HttpAuthHandlerNegotiate( AuthLibrary* auth_library, #if defined(OS_WIN) @@ -46,88 +108,6 @@ HttpAuthHandlerNegotiate::HttpAuthHandlerNegotiate( HttpAuthHandlerNegotiate::~HttpAuthHandlerNegotiate() { } -int HttpAuthHandlerNegotiate::GenerateAuthTokenImpl( - const string16* username, - const string16* password, - const HttpRequestInfo* request, - CompletionCallback* callback, - std::string* auth_token) { - DCHECK(user_callback_ == NULL); - DCHECK((username == NULL) == (password == NULL)); - DCHECK(auth_token_ == NULL); - auth_token_ = auth_token; - if (already_called_) { - DCHECK((!has_username_and_password_ && username == NULL) || - (has_username_and_password_ && *username == username_ && - *password == password_)); - next_state_ = STATE_GENERATE_AUTH_TOKEN; - } else { - already_called_ = true; - if (username) { - has_username_and_password_ = true; - username_ = *username; - password_ = *password; - } - next_state_ = STATE_RESOLVE_CANONICAL_NAME; - } - int rv = DoLoop(OK); - if (rv == ERR_IO_PENDING) - user_callback_ = callback; - return rv; -} - -// The Negotiate challenge header looks like: -// WWW-Authenticate: NEGOTIATE auth-data -bool HttpAuthHandlerNegotiate::Init(HttpAuth::ChallengeTokenizer* challenge) { -#if defined(OS_POSIX) - if (!auth_system_.Init()) { - VLOG(1) << "can't initialize GSSAPI library"; - return false; - } - // GSSAPI does not provide a way to enter username/password to - // obtain a TGT. If the default credentials are not allowed for - // a particular site (based on whitelist), fall back to a - // different scheme. - if (!AllowsDefaultCredentials()) - return false; -#endif - if (CanDelegate()) - auth_system_.Delegate(); - auth_scheme_ = HttpAuth::AUTH_SCHEME_NEGOTIATE; - score_ = 4; - properties_ = ENCRYPTS_IDENTITY | IS_CONNECTION_BASED; - HttpAuth::AuthorizationResult auth_result = - auth_system_.ParseChallenge(challenge); - return (auth_result == HttpAuth::AUTHORIZATION_RESULT_ACCEPT); -} - -HttpAuth::AuthorizationResult HttpAuthHandlerNegotiate::HandleAnotherChallenge( - HttpAuth::ChallengeTokenizer* challenge) { - return auth_system_.ParseChallenge(challenge); -} - -// Require identity on first pass instead of second. -bool HttpAuthHandlerNegotiate::NeedsIdentity() { - return auth_system_.NeedsIdentity(); -} - -bool HttpAuthHandlerNegotiate::AllowsDefaultCredentials() { - if (target_ == HttpAuth::AUTH_PROXY) - return true; - if (!url_security_manager_) - return false; - return url_security_manager_->CanUseDefaultCredentials(origin_); -} - -bool HttpAuthHandlerNegotiate::CanDelegate() const { - // TODO(cbentzel): Should delegation be allowed on proxies? - if (target_ == HttpAuth::AUTH_PROXY) - return false; - if (!url_security_manager_) - return false; - return url_security_manager_->CanDelegate(origin_); -} - std::wstring HttpAuthHandlerNegotiate::CreateSPN( const AddressList& address_list, const GURL& origin) { // Kerberos Web Server SPNs are in the form HTTP/<host>:<port> through SSPI, @@ -177,6 +157,93 @@ std::wstring HttpAuthHandlerNegotiate::CreateSPN( } } +HttpAuth::AuthorizationResult HttpAuthHandlerNegotiate::HandleAnotherChallenge( + HttpAuth::ChallengeTokenizer* challenge) { + return auth_system_.ParseChallenge(challenge); +} + +// Require identity on first pass instead of second. +bool HttpAuthHandlerNegotiate::NeedsIdentity() { + return auth_system_.NeedsIdentity(); +} + +bool HttpAuthHandlerNegotiate::AllowsDefaultCredentials() { + if (target_ == HttpAuth::AUTH_PROXY) + return true; + if (!url_security_manager_) + return false; + return url_security_manager_->CanUseDefaultCredentials(origin_); +} + +// The Negotiate challenge header looks like: +// WWW-Authenticate: NEGOTIATE auth-data +bool HttpAuthHandlerNegotiate::Init(HttpAuth::ChallengeTokenizer* challenge) { +#if defined(OS_POSIX) + if (!auth_system_.Init()) { + VLOG(1) << "can't initialize GSSAPI library"; + return false; + } + // GSSAPI does not provide a way to enter username/password to + // obtain a TGT. If the default credentials are not allowed for + // a particular site (based on whitelist), fall back to a + // different scheme. + if (!AllowsDefaultCredentials()) + return false; +#endif + if (CanDelegate()) + auth_system_.Delegate(); + auth_scheme_ = HttpAuth::AUTH_SCHEME_NEGOTIATE; + score_ = 4; + properties_ = ENCRYPTS_IDENTITY | IS_CONNECTION_BASED; + HttpAuth::AuthorizationResult auth_result = + auth_system_.ParseChallenge(challenge); + return (auth_result == HttpAuth::AUTHORIZATION_RESULT_ACCEPT); +} + +int HttpAuthHandlerNegotiate::GenerateAuthTokenImpl( + const string16* username, + const string16* password, + const HttpRequestInfo* request, + CompletionCallback* callback, + std::string* auth_token) { + DCHECK(user_callback_ == NULL); + DCHECK((username == NULL) == (password == NULL)); + DCHECK(auth_token_ == NULL); + auth_token_ = auth_token; + if (already_called_) { + DCHECK((!has_username_and_password_ && username == NULL) || + (has_username_and_password_ && *username == username_ && + *password == password_)); + next_state_ = STATE_GENERATE_AUTH_TOKEN; + } else { + already_called_ = true; + if (username) { + has_username_and_password_ = true; + username_ = *username; + password_ = *password; + } + next_state_ = STATE_RESOLVE_CANONICAL_NAME; + } + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) + user_callback_ = callback; + return rv; +} + +void HttpAuthHandlerNegotiate::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + DoCallback(rv); +} + +void HttpAuthHandlerNegotiate::DoCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_callback_); + CompletionCallback* callback = user_callback_; + user_callback_ = NULL; + callback->Run(rv); +} + int HttpAuthHandlerNegotiate::DoLoop(int result) { DCHECK(next_state_ != STATE_NONE); @@ -253,80 +320,13 @@ int HttpAuthHandlerNegotiate::DoGenerateAuthTokenComplete(int rv) { return rv; } -void HttpAuthHandlerNegotiate::OnIOComplete(int result) { - int rv = DoLoop(result); - if (rv != ERR_IO_PENDING) - DoCallback(rv); -} - -void HttpAuthHandlerNegotiate::DoCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_callback_); - CompletionCallback* callback = user_callback_; - user_callback_ = NULL; - callback->Run(rv); -} - -HttpAuthHandlerNegotiate::Factory::Factory() - : disable_cname_lookup_(false), - use_port_(false), -#if defined(OS_WIN) - max_token_length_(0), - first_creation_(true), - is_unsupported_(false), -#endif - auth_library_(NULL) { -} - -HttpAuthHandlerNegotiate::Factory::~Factory() { -} - -void HttpAuthHandlerNegotiate::Factory::set_host_resolver( - HostResolver* resolver) { - resolver_ = resolver; -} - -int HttpAuthHandlerNegotiate::Factory::CreateAuthHandler( - HttpAuth::ChallengeTokenizer* challenge, - HttpAuth::Target target, - const GURL& origin, - CreateReason reason, - int digest_nonce_count, - const BoundNetLog& net_log, - scoped_ptr<HttpAuthHandler>* handler) { -#if defined(OS_WIN) - if (is_unsupported_ || reason == CREATE_PREEMPTIVE) - return ERR_UNSUPPORTED_AUTH_SCHEME; - if (max_token_length_ == 0) { - int rv = DetermineMaxTokenLength(auth_library_.get(), NEGOSSP_NAME, - &max_token_length_); - if (rv == ERR_UNSUPPORTED_AUTH_SCHEME) - is_unsupported_ = true; - if (rv != OK) - return rv; - } - // TODO(cbentzel): Move towards model of parsing in the factory - // method and only constructing when valid. - scoped_ptr<HttpAuthHandler> tmp_handler( - new HttpAuthHandlerNegotiate(auth_library_.get(), max_token_length_, - url_security_manager(), resolver_, - disable_cname_lookup_, use_port_)); - if (!tmp_handler->InitFromChallenge(challenge, target, origin, net_log)) - return ERR_INVALID_RESPONSE; - handler->swap(tmp_handler); - return OK; -#elif defined(OS_POSIX) - // TODO(ahendrickson): Move towards model of parsing in the factory - // method and only constructing when valid. - scoped_ptr<HttpAuthHandler> tmp_handler( - new HttpAuthHandlerNegotiate(auth_library_.get(), url_security_manager(), - resolver_, disable_cname_lookup_, - use_port_)); - if (!tmp_handler->InitFromChallenge(challenge, target, origin, net_log)) - return ERR_INVALID_RESPONSE; - handler->swap(tmp_handler); - return OK; -#endif +bool HttpAuthHandlerNegotiate::CanDelegate() const { + // TODO(cbentzel): Should delegation be allowed on proxies? + if (target_ == HttpAuth::AUTH_PROXY) + return false; + if (!url_security_manager_) + return false; + return url_security_manager_->CanDelegate(origin_); } } // namespace net diff --git a/net/http/http_auth_handler_negotiate.h b/net/http/http_auth_handler_negotiate.h index a19d182..de2a6e6 100644 --- a/net/http/http_auth_handler_negotiate.h +++ b/net/http/http_auth_handler_negotiate.h @@ -64,6 +64,12 @@ class HttpAuthHandlerNegotiate : public HttpAuthHandler { void set_host_resolver(HostResolver* host_resolver); + // Sets the system library to use, thereby assuming ownership of + // |auth_library|. + void set_library(AuthLibrary* auth_library) { + auth_library_.reset(auth_library); + } + virtual int CreateAuthHandler(HttpAuth::ChallengeTokenizer* challenge, HttpAuth::Target target, const GURL& origin, @@ -72,12 +78,6 @@ class HttpAuthHandlerNegotiate : public HttpAuthHandler { const BoundNetLog& net_log, scoped_ptr<HttpAuthHandler>* handler); - // Sets the system library to use, thereby assuming ownership of - // |auth_library|. - void set_library(AuthLibrary* auth_library) { - auth_library_.reset(auth_library); - } - private: bool disable_cname_lookup_; bool use_port_; @@ -101,17 +101,16 @@ class HttpAuthHandlerNegotiate : public HttpAuthHandler { virtual ~HttpAuthHandlerNegotiate(); - virtual bool NeedsIdentity(); - - virtual bool AllowsDefaultCredentials(); - - virtual HttpAuth::AuthorizationResult HandleAnotherChallenge( - HttpAuth::ChallengeTokenizer* challenge); - // These are public for unit tests std::wstring CreateSPN(const AddressList& address_list, const GURL& orign); const std::wstring& spn() const { return spn_; } + // HttpAuthHandler: + virtual HttpAuth::AuthorizationResult HandleAnotherChallenge( + HttpAuth::ChallengeTokenizer* challenge); + virtual bool NeedsIdentity(); + virtual bool AllowsDefaultCredentials(); + protected: virtual bool Init(HttpAuth::ChallengeTokenizer* challenge); diff --git a/net/http/http_auth_handler_ntlm.cc b/net/http/http_auth_handler_ntlm.cc index c3e44ba..5090918 100644 --- a/net/http/http_auth_handler_ntlm.cc +++ b/net/http/http_auth_handler_ntlm.cc @@ -15,6 +15,19 @@ namespace net { +HttpAuth::AuthorizationResult HttpAuthHandlerNTLM::HandleAnotherChallenge( + HttpAuth::ChallengeTokenizer* challenge) { + return ParseChallenge(challenge, false); +} + +bool HttpAuthHandlerNTLM::Init(HttpAuth::ChallengeTokenizer* tok) { + auth_scheme_ = HttpAuth::AUTH_SCHEME_NTLM; + score_ = 3; + properties_ = ENCRYPTS_IDENTITY | IS_CONNECTION_BASED; + + return ParseChallenge(tok, true) == HttpAuth::AUTHORIZATION_RESULT_ACCEPT; +} + int HttpAuthHandlerNTLM::GenerateAuthTokenImpl( const string16* username, const string16* password, @@ -92,19 +105,6 @@ int HttpAuthHandlerNTLM::GenerateAuthTokenImpl( #endif } -bool HttpAuthHandlerNTLM::Init(HttpAuth::ChallengeTokenizer* tok) { - auth_scheme_ = HttpAuth::AUTH_SCHEME_NTLM; - score_ = 3; - properties_ = ENCRYPTS_IDENTITY | IS_CONNECTION_BASED; - - return ParseChallenge(tok, true) == HttpAuth::AUTHORIZATION_RESULT_ACCEPT; -} - -HttpAuth::AuthorizationResult HttpAuthHandlerNTLM::HandleAnotherChallenge( - HttpAuth::ChallengeTokenizer* challenge) { - return ParseChallenge(challenge, false); -} - // The NTLM challenge header looks like: // WWW-Authenticate: NTLM auth-data HttpAuth::AuthorizationResult HttpAuthHandlerNTLM::ParseChallenge( diff --git a/net/http/http_auth_handler_ntlm.h b/net/http/http_auth_handler_ntlm.h index 831e43d..ae7c78b 100644 --- a/net/http/http_auth_handler_ntlm.h +++ b/net/http/http_auth_handler_ntlm.h @@ -114,6 +114,10 @@ class HttpAuthHandlerNTLM : public HttpAuthHandler { HttpAuth::ChallengeTokenizer* challenge); protected: + // This function acquires a credentials handle in the SSPI implementation. + // It does nothing in the portable implementation. + int InitializeBeforeFirstChallenge(); + virtual bool Init(HttpAuth::ChallengeTokenizer* tok); virtual int GenerateAuthTokenImpl(const string16* username, @@ -122,10 +126,6 @@ class HttpAuthHandlerNTLM : public HttpAuthHandler { CompletionCallback* callback, std::string* auth_token); - // This function acquires a credentials handle in the SSPI implementation. - // It does nothing in the portable implementation. - int InitializeBeforeFirstChallenge(); - private: ~HttpAuthHandlerNTLM(); diff --git a/net/http/http_auth_handler_ntlm_portable.cc b/net/http/http_auth_handler_ntlm_portable.cc index d3abc98..2b06e58 100644 --- a/net/http/http_auth_handler_ntlm_portable.cc +++ b/net/http/http_auth_handler_ntlm_portable.cc @@ -643,12 +643,6 @@ HttpAuthHandlerNTLM::get_host_name_proc_ = GetHostName; HttpAuthHandlerNTLM::HttpAuthHandlerNTLM() { } -HttpAuthHandlerNTLM::~HttpAuthHandlerNTLM() { - // Wipe our copy of the password from memory, to reduce the chance of being - // written to the paging file on disk. - ZapString(&password_); -} - bool HttpAuthHandlerNTLM::NeedsIdentity() { return !auth_data_.empty(); } @@ -659,6 +653,16 @@ bool HttpAuthHandlerNTLM::AllowsDefaultCredentials() { return false; } +int HttpAuthHandlerNTLM::InitializeBeforeFirstChallenge() { + return OK; +} + +HttpAuthHandlerNTLM::~HttpAuthHandlerNTLM() { + // Wipe our copy of the password from memory, to reduce the chance of being + // written to the paging file on disk. + ZapString(&password_); +} + // static HttpAuthHandlerNTLM::GenerateRandomProc HttpAuthHandlerNTLM::SetGenerateRandomProc( @@ -676,6 +680,12 @@ HttpAuthHandlerNTLM::HostNameProc HttpAuthHandlerNTLM::SetHostNameProc( return old_proc; } +HttpAuthHandlerNTLM::Factory::Factory() { +} + +HttpAuthHandlerNTLM::Factory::~Factory() { +} + int HttpAuthHandlerNTLM::GetNextToken(const void* in_token, uint32 in_token_len, void** out_token, @@ -702,16 +712,6 @@ int HttpAuthHandlerNTLM::GetNextToken(const void* in_token, return rv; } -int HttpAuthHandlerNTLM::InitializeBeforeFirstChallenge() { - return OK; -} - -HttpAuthHandlerNTLM::Factory::Factory() { -} - -HttpAuthHandlerNTLM::Factory::~Factory() { -} - int HttpAuthHandlerNTLM::Factory::CreateAuthHandler( HttpAuth::ChallengeTokenizer* challenge, HttpAuth::Target target, diff --git a/net/http/http_basic_stream.cc b/net/http/http_basic_stream.cc index 061bb30..3e69d7a 100644 --- a/net/http/http_basic_stream.cc +++ b/net/http/http_basic_stream.cc @@ -25,6 +25,8 @@ HttpBasicStream::HttpBasicStream(ClientSocketHandle* connection, request_info_(NULL) { } +HttpBasicStream::~HttpBasicStream() {} + int HttpBasicStream::InitializeStream(const HttpRequestInfo* request_info, const BoundNetLog& net_log, CompletionCallback* callback) { @@ -52,8 +54,6 @@ int HttpBasicStream::SendRequest(const HttpRequestHeaders& headers, callback); } -HttpBasicStream::~HttpBasicStream() {} - uint64 HttpBasicStream::GetUploadProgress() const { return parser_->GetUploadProgress(); } diff --git a/net/http/http_cache_transaction.cc b/net/http/http_cache_transaction.cc index b506edc..32386f0 100644 --- a/net/http/http_cache_transaction.cc +++ b/net/http/http_cache_transaction.cc @@ -163,6 +163,70 @@ HttpCache::Transaction::~Transaction() { cache_.reset(); } +int HttpCache::Transaction::WriteMetadata(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + DCHECK(buf); + DCHECK_GT(buf_len, 0); + DCHECK(callback); + if (!cache_ || !entry_) + return ERR_UNEXPECTED; + + // We don't need to track this operation for anything. + // It could be possible to check if there is something already written and + // avoid writing again (it should be the same, right?), but let's allow the + // caller to "update" the contents with something new. + return entry_->disk_entry->WriteData(kMetadataIndex, 0, buf, buf_len, + callback, true); +} + +// Histogram data from the end of 2010 show the following distribution of +// response headers: +// +// Content-Length............... 87% +// Date......................... 98% +// Last-Modified................ 49% +// Etag......................... 19% +// Accept-Ranges: bytes......... 25% +// Accept-Ranges: none.......... 0.4% +// Strong Validator............. 50% +// Strong Validator + ranges.... 24% +// Strong Validator + CL........ 49% +// +bool HttpCache::Transaction::AddTruncatedFlag() { + DCHECK(mode_ & WRITE); + + // Don't set the flag for sparse entries. + if (partial_.get() && !truncated_) + return true; + + // Double check that there is something worth keeping. + if (!entry_->disk_entry->GetDataSize(kResponseContentIndex)) + return false; + + if (response_.headers->GetContentLength() <= 0 || + response_.headers->HasHeaderValue("Accept-Ranges", "none") || + !response_.headers->HasStrongValidators()) + return false; + + truncated_ = true; + target_state_ = STATE_NONE; + next_state_ = STATE_CACHE_WRITE_TRUNCATED_RESPONSE; + DoLoop(OK); + return true; +} + +LoadState HttpCache::Transaction::GetWriterLoadState() const { + if (network_trans_.get()) + return network_trans_->GetLoadState(); + if (entry_ || !request_) + return LOAD_STATE_IDLE; + return LOAD_STATE_WAITING_FOR_CACHE; +} + +const BoundNetLog& HttpCache::Transaction::net_log() const { + return net_log_; +} + int HttpCache::Transaction::Start(const HttpRequestInfo* request, CompletionCallback* callback, const BoundNetLog& net_log) { @@ -338,70 +402,6 @@ uint64 HttpCache::Transaction::GetUploadProgress() const { return final_upload_progress_; } -int HttpCache::Transaction::WriteMetadata(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - DCHECK(buf); - DCHECK_GT(buf_len, 0); - DCHECK(callback); - if (!cache_ || !entry_) - return ERR_UNEXPECTED; - - // We don't need to track this operation for anything. - // It could be possible to check if there is something already written and - // avoid writing again (it should be the same, right?), but let's allow the - // caller to "update" the contents with something new. - return entry_->disk_entry->WriteData(kMetadataIndex, 0, buf, buf_len, - callback, true); -} - -// Histogram data from the end of 2010 show the following distribution of -// response headers: -// -// Content-Length............... 87% -// Date......................... 98% -// Last-Modified................ 49% -// Etag......................... 19% -// Accept-Ranges: bytes......... 25% -// Accept-Ranges: none.......... 0.4% -// Strong Validator............. 50% -// Strong Validator + ranges.... 24% -// Strong Validator + CL........ 49% -// -bool HttpCache::Transaction::AddTruncatedFlag() { - DCHECK(mode_ & WRITE); - - // Don't set the flag for sparse entries. - if (partial_.get() && !truncated_) - return true; - - // Double check that there is something worth keeping. - if (!entry_->disk_entry->GetDataSize(kResponseContentIndex)) - return false; - - if (response_.headers->GetContentLength() <= 0 || - response_.headers->HasHeaderValue("Accept-Ranges", "none") || - !response_.headers->HasStrongValidators()) - return false; - - truncated_ = true; - target_state_ = STATE_NONE; - next_state_ = STATE_CACHE_WRITE_TRUNCATED_RESPONSE; - DoLoop(OK); - return true; -} - -LoadState HttpCache::Transaction::GetWriterLoadState() const { - if (network_trans_.get()) - return network_trans_->GetLoadState(); - if (entry_ || !request_) - return LOAD_STATE_IDLE; - return LOAD_STATE_WAITING_FOR_CACHE; -} - -const BoundNetLog& HttpCache::Transaction::net_log() const { - return net_log_; -} - //----------------------------------------------------------------------------- void HttpCache::Transaction::DoCallback(int rv) { diff --git a/net/http/http_cache_transaction.h b/net/http/http_cache_transaction.h index 316c15b..81160d5 100644 --- a/net/http/http_cache_transaction.h +++ b/net/http/http_cache_transaction.h @@ -28,25 +28,6 @@ struct HttpRequestInfo; // factory. class HttpCache::Transaction : public HttpTransaction { public: - Transaction(HttpCache* cache); - virtual ~Transaction(); - - // HttpTransaction methods: - virtual int Start(const HttpRequestInfo*, CompletionCallback*, - const BoundNetLog&); - virtual int RestartIgnoringLastError(CompletionCallback* callback); - virtual int RestartWithCertificate(X509Certificate* client_cert, - CompletionCallback* callback); - virtual int RestartWithAuth(const string16& username, - const string16& password, - CompletionCallback* callback); - virtual bool IsReadyToRestartForAuth(); - virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback); - virtual void StopCaching(); - virtual const HttpResponseInfo* GetResponseInfo() const; - virtual LoadState GetLoadState() const; - virtual uint64 GetUploadProgress(void) const; - // The transaction has the following modes, which apply to how it may access // its cache entry. // @@ -76,6 +57,9 @@ class HttpCache::Transaction : public HttpTransaction { UPDATE = READ_META | WRITE, // READ_WRITE & ~READ_DATA }; + Transaction(HttpCache* cache); + virtual ~Transaction(); + Mode mode() const { return mode_; } const std::string& key() const { return cache_key_; } @@ -112,6 +96,22 @@ class HttpCache::Transaction : public HttpTransaction { const BoundNetLog& net_log() const; + // HttpTransaction methods: + virtual int Start(const HttpRequestInfo*, CompletionCallback*, + const BoundNetLog&); + virtual int RestartIgnoringLastError(CompletionCallback* callback); + virtual int RestartWithCertificate(X509Certificate* client_cert, + CompletionCallback* callback); + virtual int RestartWithAuth(const string16& username, + const string16& password, + CompletionCallback* callback); + virtual bool IsReadyToRestartForAuth(); + virtual int Read(IOBuffer* buf, int buf_len, CompletionCallback* callback); + virtual void StopCaching(); + virtual const HttpResponseInfo* GetResponseInfo() const; + virtual LoadState GetLoadState() const; + virtual uint64 GetUploadProgress(void) const; + private: static const size_t kNumValidationHeaders = 2; // Helper struct to pair a header name with its value, for diff --git a/net/http/http_network_layer.cc b/net/http/http_network_layer.cc index 9a11034..975e75c 100644 --- a/net/http/http_network_layer.cc +++ b/net/http/http_network_layer.cc @@ -19,38 +19,6 @@ namespace net { //----------------------------------------------------------------------------- -// static -HttpTransactionFactory* HttpNetworkLayer::CreateFactory( - HostResolver* host_resolver, - CertVerifier* cert_verifier, - DnsRRResolver* dnsrr_resolver, - DnsCertProvenanceChecker* dns_cert_checker, - SSLHostInfoFactory* ssl_host_info_factory, - ProxyService* proxy_service, - SSLConfigService* ssl_config_service, - HttpAuthHandlerFactory* http_auth_handler_factory, - HttpNetworkDelegate* network_delegate, - NetLog* net_log) { - DCHECK(proxy_service); - - return new HttpNetworkLayer(ClientSocketFactory::GetDefaultFactory(), - host_resolver, cert_verifier, dnsrr_resolver, - dns_cert_checker, - ssl_host_info_factory, proxy_service, - ssl_config_service, http_auth_handler_factory, - network_delegate, - net_log); -} - -// static -HttpTransactionFactory* HttpNetworkLayer::CreateFactory( - HttpNetworkSession* session) { - DCHECK(session); - - return new HttpNetworkLayer(session); -} - -//----------------------------------------------------------------------------- HttpNetworkLayer::HttpNetworkLayer( ClientSocketFactory* socket_factory, HostResolver* host_resolver, @@ -132,56 +100,37 @@ HttpNetworkLayer::HttpNetworkLayer(HttpNetworkSession* session) HttpNetworkLayer::~HttpNetworkLayer() { } -int HttpNetworkLayer::CreateTransaction(scoped_ptr<HttpTransaction>* trans) { - if (suspended_) - return ERR_NETWORK_IO_SUSPENDED; +//----------------------------------------------------------------------------- - trans->reset(new HttpNetworkTransaction(GetSession())); - return OK; -} +// static +HttpTransactionFactory* HttpNetworkLayer::CreateFactory( + HostResolver* host_resolver, + CertVerifier* cert_verifier, + DnsRRResolver* dnsrr_resolver, + DnsCertProvenanceChecker* dns_cert_checker, + SSLHostInfoFactory* ssl_host_info_factory, + ProxyService* proxy_service, + SSLConfigService* ssl_config_service, + HttpAuthHandlerFactory* http_auth_handler_factory, + HttpNetworkDelegate* network_delegate, + NetLog* net_log) { + DCHECK(proxy_service); -HttpCache* HttpNetworkLayer::GetCache() { - return NULL; + return new HttpNetworkLayer(ClientSocketFactory::GetDefaultFactory(), + host_resolver, cert_verifier, dnsrr_resolver, + dns_cert_checker, + ssl_host_info_factory, proxy_service, + ssl_config_service, http_auth_handler_factory, + network_delegate, + net_log); } -void HttpNetworkLayer::Suspend(bool suspend) { - suspended_ = suspend; - - if (suspend && session_) - session_->tcp_socket_pool()->CloseIdleSockets(); -} +// static +HttpTransactionFactory* HttpNetworkLayer::CreateFactory( + HttpNetworkSession* session) { + DCHECK(session); -HttpNetworkSession* HttpNetworkLayer::GetSession() { - if (!session_) { - DCHECK(proxy_service_); - if (!spdy_session_pool_.get()) - spdy_session_pool_.reset(new SpdySessionPool(ssl_config_service_)); - session_ = new HttpNetworkSession( - host_resolver_, - cert_verifier_, - dnsrr_resolver_, - dns_cert_checker_, - ssl_host_info_factory_, - proxy_service_, - socket_factory_, - ssl_config_service_, - spdy_session_pool_.release(), - http_auth_handler_factory_, - network_delegate_, - net_log_); - // These were just temps for lazy-initializing HttpNetworkSession. - host_resolver_ = NULL; - cert_verifier_ = NULL; - dnsrr_resolver_ = NULL; - dns_cert_checker_ = NULL; - ssl_host_info_factory_ = NULL; - proxy_service_ = NULL; - socket_factory_ = NULL; - http_auth_handler_factory_ = NULL; - net_log_ = NULL; - network_delegate_ = NULL; - } - return session_; + return new HttpNetworkLayer(session); } // static @@ -277,4 +226,59 @@ void HttpNetworkLayer::EnableSpdy(const std::string& mode) { } } } + +//----------------------------------------------------------------------------- + +int HttpNetworkLayer::CreateTransaction(scoped_ptr<HttpTransaction>* trans) { + if (suspended_) + return ERR_NETWORK_IO_SUSPENDED; + + trans->reset(new HttpNetworkTransaction(GetSession())); + return OK; +} + +HttpCache* HttpNetworkLayer::GetCache() { + return NULL; +} + +HttpNetworkSession* HttpNetworkLayer::GetSession() { + if (!session_) { + DCHECK(proxy_service_); + if (!spdy_session_pool_.get()) + spdy_session_pool_.reset(new SpdySessionPool(ssl_config_service_)); + session_ = new HttpNetworkSession( + host_resolver_, + cert_verifier_, + dnsrr_resolver_, + dns_cert_checker_, + ssl_host_info_factory_, + proxy_service_, + socket_factory_, + ssl_config_service_, + spdy_session_pool_.release(), + http_auth_handler_factory_, + network_delegate_, + net_log_); + // These were just temps for lazy-initializing HttpNetworkSession. + host_resolver_ = NULL; + cert_verifier_ = NULL; + dnsrr_resolver_ = NULL; + dns_cert_checker_ = NULL; + ssl_host_info_factory_ = NULL; + proxy_service_ = NULL; + socket_factory_ = NULL; + http_auth_handler_factory_ = NULL; + net_log_ = NULL; + network_delegate_ = NULL; + } + return session_; +} + +void HttpNetworkLayer::Suspend(bool suspend) { + suspended_ = suspend; + + if (suspend && session_) + session_->tcp_socket_pool()->CloseIdleSockets(); +} + } // namespace net diff --git a/net/http/http_network_layer.h b/net/http/http_network_layer.h index 730b5c7..963ebee 100644 --- a/net/http/http_network_layer.h +++ b/net/http/http_network_layer.h @@ -78,6 +78,7 @@ class HttpNetworkLayer : public HttpTransactionFactory, HttpAuthHandlerFactory* http_auth_handler_factory, HttpNetworkDelegate* network_delegate, NetLog* net_log); + // Create a transaction factory that instantiate a network layer over an // existing network session. Network session contains some valuable // information (e.g. authentication data) that we want to share across @@ -86,12 +87,6 @@ class HttpNetworkLayer : public HttpTransactionFactory, // when network session is shared. static HttpTransactionFactory* CreateFactory(HttpNetworkSession* session); - // HttpTransactionFactory methods: - virtual int CreateTransaction(scoped_ptr<HttpTransaction>* trans); - virtual HttpCache* GetCache(); - virtual HttpNetworkSession* GetSession(); - virtual void Suspend(bool suspend); - // Enable the spdy protocol. // Without calling this function, SPDY is disabled. The mode can be: // "" : (default) SSL and compression are enabled, flow @@ -102,6 +97,12 @@ class HttpNetworkLayer : public HttpTransactionFactory, // "none" : disables both SSL and compression. static void EnableSpdy(const std::string& mode); + // HttpTransactionFactory methods: + virtual int CreateTransaction(scoped_ptr<HttpTransaction>* trans); + virtual HttpCache* GetCache(); + virtual HttpNetworkSession* GetSession(); + virtual void Suspend(bool suspend); + private: // The factory we will use to create network sockets. ClientSocketFactory* socket_factory_; diff --git a/net/http/http_request_headers.cc b/net/http/http_request_headers.cc index 9ce77bf..9d523c1 100644 --- a/net/http/http_request_headers.cc +++ b/net/http/http_request_headers.cc @@ -77,13 +77,6 @@ void HttpRequestHeaders::Clear() { headers_.clear(); } -void HttpRequestHeaders::SetHeaderIfMissing(const base::StringPiece& key, - const base::StringPiece& value) { - HeaderVector::iterator it = FindHeader(key); - if (it == headers_.end()) - headers_.push_back(HeaderKeyValuePair(key.as_string(), value.as_string())); -} - void HttpRequestHeaders::SetHeader(const base::StringPiece& key, const base::StringPiece& value) { HeaderVector::iterator it = FindHeader(key); @@ -93,6 +86,13 @@ void HttpRequestHeaders::SetHeader(const base::StringPiece& key, headers_.push_back(HeaderKeyValuePair(key.as_string(), value.as_string())); } +void HttpRequestHeaders::SetHeaderIfMissing(const base::StringPiece& key, + const base::StringPiece& value) { + HeaderVector::iterator it = FindHeader(key); + if (it == headers_.end()) + headers_.push_back(HeaderKeyValuePair(key.as_string(), value.as_string())); +} + void HttpRequestHeaders::RemoveHeader(const base::StringPiece& key) { HeaderVector::iterator it = FindHeader(key); if (it != headers_.end()) diff --git a/net/http/mock_gssapi_library_posix.cc b/net/http/mock_gssapi_library_posix.cc index ec69964..1ca5040 100644 --- a/net/http/mock_gssapi_library_posix.cc +++ b/net/http/mock_gssapi_library_posix.cc @@ -194,6 +194,23 @@ MockGSSAPILibrary::MockGSSAPILibrary() { MockGSSAPILibrary::~MockGSSAPILibrary() { } +void MockGSSAPILibrary::ExpectSecurityContext( + const std::string& expected_package, + OM_uint32 response_code, + OM_uint32 minor_response_code, + const GssContextMockImpl& context_info, + const gss_buffer_desc& expected_input_token, + const gss_buffer_desc& output_token) { + SecurityContextQuery security_query; + security_query.expected_package = expected_package; + security_query.response_code = response_code; + security_query.minor_response_code = minor_response_code; + security_query.context_info.Assign(context_info); + security_query.expected_input_token = expected_input_token; + security_query.output_token = output_token; + expected_security_queries_.push_back(security_query); +} + bool MockGSSAPILibrary::Init() { return true; } @@ -417,23 +434,6 @@ OM_uint32 MockGSSAPILibrary::inquire_context( return GSS_S_COMPLETE; } -void MockGSSAPILibrary::ExpectSecurityContext( - const std::string& expected_package, - OM_uint32 response_code, - OM_uint32 minor_response_code, - const GssContextMockImpl& context_info, - const gss_buffer_desc& expected_input_token, - const gss_buffer_desc& output_token) { - SecurityContextQuery security_query; - security_query.expected_package = expected_package; - security_query.response_code = response_code; - security_query.minor_response_code = minor_response_code; - security_query.context_info.Assign(context_info); - security_query.expected_input_token = expected_input_token; - security_query.output_token = output_token; - expected_security_queries_.push_back(security_query); -} - } // namespace test } // namespace net diff --git a/net/http/mock_gssapi_library_posix.h b/net/http/mock_gssapi_library_posix.h index 15e14f2..f0652d3 100644 --- a/net/http/mock_gssapi_library_posix.h +++ b/net/http/mock_gssapi_library_posix.h @@ -45,10 +45,61 @@ class GssContextMockImpl { // the system GSSAPI library calls. class MockGSSAPILibrary : public GSSAPILibrary { public: + // Unit tests need access to this. "Friend"ing didn't help. + struct SecurityContextQuery { + std::string expected_package; + OM_uint32 response_code; + OM_uint32 minor_response_code; + test::GssContextMockImpl context_info; + gss_buffer_desc expected_input_token; + gss_buffer_desc output_token; + }; MockGSSAPILibrary(); virtual ~MockGSSAPILibrary(); + // Establishes an expectation for a |init_sec_context()| call. + // + // Each expectation established by |ExpectSecurityContext()| must be + // matched by a call to |init_sec_context()| during the lifetime of + // the MockGSSAPILibrary. The |expected_package| argument must equal the + // value associated with the |target_name| argument to |init_sec_context()| + // for there to be a match. The expectations also establish an explicit + // ordering. + // + // For example, this sequence will be successful. + // MockGSSAPILibrary lib; + // lib.ExpectSecurityContext("NTLM", ...) + // lib.ExpectSecurityContext("Negotiate", ...) + // lib.init_sec_context("NTLM", ...) + // lib.init_sec_context("Negotiate", ...) + // + // This sequence will fail since the queries do not occur in the order + // established by the expectations. + // MockGSSAPILibrary lib; + // lib.ExpectSecurityContext("NTLM", ...) + // lib.ExpectSecurityContext("Negotiate", ...) + // lib.init_sec_context("Negotiate", ...) + // lib.init_sec_context("NTLM", ...) + // + // This sequence will fail because there were not enough queries. + // MockGSSAPILibrary lib; + // lib.ExpectSecurityContext("NTLM", ...) + // lib.ExpectSecurityContext("Negotiate", ...) + // lib.init_sec_context("NTLM", ...) + // + // |response_code| is used as the return value for |init_sec_context()|. + // If |response_code| is GSS_S_COMPLETE, + // + // |context_info| is the expected value of the |**context_handle| in after + // |init_sec_context()| returns. + void ExpectSecurityContext(const std::string& expected_package, + OM_uint32 response_code, + OM_uint32 minor_response_code, + const test::GssContextMockImpl& context_info, + const gss_buffer_desc& expected_input_token, + const gss_buffer_desc& output_token); + // GSSAPILibrary methods: // Initializes the library, including any necessary dynamic libraries. @@ -116,58 +167,6 @@ class MockGSSAPILibrary : public GSSAPILibrary { int* locally_initiated, int* open); - // Establishes an expectation for a |init_sec_context()| call. - // - // Each expectation established by |ExpectSecurityContext()| must be - // matched by a call to |init_sec_context()| during the lifetime of - // the MockGSSAPILibrary. The |expected_package| argument must equal the - // value associated with the |target_name| argument to |init_sec_context()| - // for there to be a match. The expectations also establish an explicit - // ordering. - // - // For example, this sequence will be successful. - // MockGSSAPILibrary lib; - // lib.ExpectSecurityContext("NTLM", ...) - // lib.ExpectSecurityContext("Negotiate", ...) - // lib.init_sec_context("NTLM", ...) - // lib.init_sec_context("Negotiate", ...) - // - // This sequence will fail since the queries do not occur in the order - // established by the expectations. - // MockGSSAPILibrary lib; - // lib.ExpectSecurityContext("NTLM", ...) - // lib.ExpectSecurityContext("Negotiate", ...) - // lib.init_sec_context("Negotiate", ...) - // lib.init_sec_context("NTLM", ...) - // - // This sequence will fail because there were not enough queries. - // MockGSSAPILibrary lib; - // lib.ExpectSecurityContext("NTLM", ...) - // lib.ExpectSecurityContext("Negotiate", ...) - // lib.init_sec_context("NTLM", ...) - // - // |response_code| is used as the return value for |init_sec_context()|. - // If |response_code| is GSS_S_COMPLETE, - // - // |context_info| is the expected value of the |**context_handle| in after - // |init_sec_context()| returns. - void ExpectSecurityContext(const std::string& expected_package, - OM_uint32 response_code, - OM_uint32 minor_response_code, - const test::GssContextMockImpl& context_info, - const gss_buffer_desc& expected_input_token, - const gss_buffer_desc& output_token); - - // Unit tests need access to this. "Friend"ing didn't help. - struct SecurityContextQuery { - std::string expected_package; - OM_uint32 response_code; - OM_uint32 minor_response_code; - test::GssContextMockImpl context_info; - gss_buffer_desc expected_input_token; - gss_buffer_desc output_token; - }; - private: FRIEND_TEST_ALL_PREFIXES(HttpAuthGSSAPIPOSIXTest, GSSAPICycle); diff --git a/net/proxy/proxy_list.cc b/net/proxy/proxy_list.cc index 236e707..011aab9 100644 --- a/net/proxy/proxy_list.cc +++ b/net/proxy/proxy_list.cc @@ -86,17 +86,6 @@ const ProxyServer& ProxyList::Get() const { return proxies_[0]; } -std::string ProxyList::ToPacString() const { - std::string proxy_list; - std::vector<ProxyServer>::const_iterator iter = proxies_.begin(); - for (; iter != proxies_.end(); ++iter) { - if (!proxy_list.empty()) - proxy_list += ";"; - proxy_list += iter->ToPacString(); - } - return proxy_list.empty() ? std::string() : proxy_list; -} - void ProxyList::SetFromPacString(const std::string& pac_string) { StringTokenizer entry_tok(pac_string, ";"); proxies_.clear(); @@ -115,6 +104,17 @@ void ProxyList::SetFromPacString(const std::string& pac_string) { } } +std::string ProxyList::ToPacString() const { + std::string proxy_list; + std::vector<ProxyServer>::const_iterator iter = proxies_.begin(); + for (; iter != proxies_.end(); ++iter) { + if (!proxy_list.empty()) + proxy_list += ";"; + proxy_list += iter->ToPacString(); + } + return proxy_list.empty() ? std::string() : proxy_list; +} + bool ProxyList::Fallback(ProxyRetryInfoMap* proxy_retry_info) { // Number of minutes to wait before retrying a bad proxy server. const TimeDelta kProxyRetryDelay = TimeDelta::FromMinutes(5); diff --git a/net/socket/client_socket.cc b/net/socket/client_socket.cc index 6b12841..3792c5c 100644 --- a/net/socket/client_socket.cc +++ b/net/socket/client_socket.cc @@ -66,42 +66,6 @@ void ClientSocket::UseHistory::Reset() { // are intentionally preserved. } -void ClientSocket::UseHistory::EmitPreconnectionHistograms() const { - DCHECK(!subresource_speculation_ || !omnibox_speculation_); - // 0 ==> non-speculative, never connected. - // 1 ==> non-speculative never used (but connected). - // 2 ==> non-speculative and used. - // 3 ==> omnibox_speculative never connected. - // 4 ==> omnibox_speculative never used (but connected). - // 5 ==> omnibox_speculative and used. - // 6 ==> subresource_speculative never connected. - // 7 ==> subresource_speculative never used (but connected). - // 8 ==> subresource_speculative and used. - int result; - if (was_used_to_convey_data_) - result = 2; - else if (was_ever_connected_) - result = 1; - else - result = 0; // Never used, and not really connected. - - if (omnibox_speculation_) - result += 3; - else if (subresource_speculation_) - result += 6; - UMA_HISTOGRAM_ENUMERATION("Net.PreconnectUtilization2", result, 9); - - static const bool connect_backup_jobs_fieldtrial = - base::FieldTrialList::Find("ConnnectBackupJobs") && - !base::FieldTrialList::Find("ConnnectBackupJobs")->group_name().empty(); - if (connect_backup_jobs_fieldtrial) { - UMA_HISTOGRAM_ENUMERATION( - base::FieldTrial::MakeName("Net.PreconnectUtilization2", - "ConnnectBackupJobs"), - result, 9); - } -} - void ClientSocket::UseHistory::set_was_ever_connected() { DCHECK(!was_used_to_convey_data_); was_ever_connected_ = true; @@ -144,6 +108,42 @@ bool ClientSocket::UseHistory::was_used_to_convey_data() const { return was_used_to_convey_data_; } +void ClientSocket::UseHistory::EmitPreconnectionHistograms() const { + DCHECK(!subresource_speculation_ || !omnibox_speculation_); + // 0 ==> non-speculative, never connected. + // 1 ==> non-speculative never used (but connected). + // 2 ==> non-speculative and used. + // 3 ==> omnibox_speculative never connected. + // 4 ==> omnibox_speculative never used (but connected). + // 5 ==> omnibox_speculative and used. + // 6 ==> subresource_speculative never connected. + // 7 ==> subresource_speculative never used (but connected). + // 8 ==> subresource_speculative and used. + int result; + if (was_used_to_convey_data_) + result = 2; + else if (was_ever_connected_) + result = 1; + else + result = 0; // Never used, and not really connected. + + if (omnibox_speculation_) + result += 3; + else if (subresource_speculation_) + result += 6; + UMA_HISTOGRAM_ENUMERATION("Net.PreconnectUtilization2", result, 9); + + static const bool connect_backup_jobs_fieldtrial = + base::FieldTrialList::Find("ConnnectBackupJobs") && + !base::FieldTrialList::Find("ConnnectBackupJobs")->group_name().empty(); + if (connect_backup_jobs_fieldtrial) { + UMA_HISTOGRAM_ENUMERATION( + base::FieldTrial::MakeName("Net.PreconnectUtilization2", + "ConnnectBackupJobs"), + result, 9); + } +} + void ClientSocket::LogByteTransfer(const BoundNetLog& net_log, NetLog::EventType event_type, int byte_count, diff --git a/net/socket/client_socket_factory.cc b/net/socket/client_socket_factory.cc index f4da066..dd201f9 100644 --- a/net/socket/client_socket_factory.cc +++ b/net/socket/client_socket_factory.cc @@ -79,17 +79,6 @@ static base::LazyInstance<DefaultClientSocketFactory> } // namespace -// static -ClientSocketFactory* ClientSocketFactory::GetDefaultFactory() { - return g_default_client_socket_factory.Pointer(); -} - -// static -void ClientSocketFactory::SetSSLClientSocketFactory( - SSLClientSocketFactory factory) { - g_ssl_factory = factory; -} - // Deprecated function (http://crbug.com/37810) that takes a ClientSocket. SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket( ClientSocket* transport_socket, @@ -104,4 +93,15 @@ SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket( NULL /* DnsCertProvenanceChecker */); } +// static +ClientSocketFactory* ClientSocketFactory::GetDefaultFactory() { + return g_default_client_socket_factory.Pointer(); +} + +// static +void ClientSocketFactory::SetSSLClientSocketFactory( + SSLClientSocketFactory factory) { + g_ssl_factory = factory; +} + } // namespace net diff --git a/net/socket/dns_cert_provenance_checker.cc b/net/socket/dns_cert_provenance_checker.cc index 665a16a..33487c5 100644 --- a/net/socket/dns_cert_provenance_checker.cc +++ b/net/socket/dns_cert_provenance_checker.cc @@ -218,6 +218,22 @@ SECKEYPublicKey* GetServerPubKey() { } // namespace +DnsCertProvenanceChecker::Delegate::~Delegate() { +} + +DnsCertProvenanceChecker::~DnsCertProvenanceChecker() { +} + +void DnsCertProvenanceChecker::DoAsyncLookup( + const std::string& hostname, + const std::vector<base::StringPiece>& der_certs, + DnsRRResolver* dnsrr_resolver, + Delegate* delegate) { + DnsCertProvenanceCheck* check = new DnsCertProvenanceCheck( + hostname, dnsrr_resolver, delegate, der_certs); + check->Start(); +} + // static std::string DnsCertProvenanceChecker::BuildEncryptedReport( const std::string& hostname, @@ -318,32 +334,16 @@ std::string DnsCertProvenanceChecker::BuildEncryptedReport( outer.size()); } -void DnsCertProvenanceChecker::DoAsyncLookup( - const std::string& hostname, - const std::vector<base::StringPiece>& der_certs, - DnsRRResolver* dnsrr_resolver, - Delegate* delegate) { - DnsCertProvenanceCheck* check = new DnsCertProvenanceCheck( - hostname, dnsrr_resolver, delegate, der_certs); - check->Start(); -} - -DnsCertProvenanceChecker::Delegate::~Delegate() { -} - -DnsCertProvenanceChecker::~DnsCertProvenanceChecker() { -} - } // namespace net #else // USE_OPENSSL namespace net { -std::string DnsCertProvenanceChecker::BuildEncryptedReport( - const std::string& hostname, - const std::vector<std::string>& der_certs) { - return ""; +DnsCertProvenanceChecker::Delegate::~Delegate() { +} + +DnsCertProvenanceChecker::~DnsCertProvenanceChecker() { } void DnsCertProvenanceChecker::DoAsyncLookup( @@ -353,10 +353,10 @@ void DnsCertProvenanceChecker::DoAsyncLookup( Delegate* delegate) { } -DnsCertProvenanceChecker::Delegate::~Delegate() { -} - -DnsCertProvenanceChecker::~DnsCertProvenanceChecker() { +std::string DnsCertProvenanceChecker::BuildEncryptedReport( + const std::string& hostname, + const std::vector<std::string>& der_certs) { + return ""; } } // namespace net diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc index d006d58..a71af2c 100644 --- a/net/socket/ssl_client_socket_nss.cc +++ b/net/socket/ssl_client_socket_nss.cc @@ -279,6 +279,162 @@ class PeerCertificateChain { CERTCertificate** certs_; }; +void DestroyCertificates(CERTCertificate** certs, unsigned len) { + for (unsigned i = 0; i < len; i++) + CERT_DestroyCertificate(certs[i]); +} + +// DNSValidationResult enumerates the possible outcomes from processing a +// set of DNS records. +enum DNSValidationResult { + DNSVR_SUCCESS, // the cert is immediately acceptable. + DNSVR_FAILURE, // the cert is unconditionally rejected. + DNSVR_CONTINUE, // perform CA validation as usual. +}; + +// VerifyTXTRecords processes the RRDATA for a number of DNS TXT records and +// checks them against the given certificate. +// dnssec: if true then the TXT records are DNSSEC validated. In this case, +// DNSVR_SUCCESS may be returned. +// server_cert_nss: the certificate to validate +// rrdatas: the TXT records for the current domain. +DNSValidationResult VerifyTXTRecords( + bool dnssec, + CERTCertificate* server_cert_nss, + const std::vector<base::StringPiece>& rrdatas) { + bool found_well_formed_record = false; + bool matched_record = false; + + for (std::vector<base::StringPiece>::const_iterator + i = rrdatas.begin(); i != rrdatas.end(); ++i) { + std::map<std::string, std::string> m( + DNSSECChainVerifier::ParseTLSTXTRecord(*i)); + if (m.empty()) + continue; + + std::map<std::string, std::string>::const_iterator j; + j = m.find("v"); + if (j == m.end() || j->second != "tls1") + continue; + + j = m.find("ha"); + + HASH_HashType hash_algorithm; + unsigned hash_length; + if (j == m.end() || j->second == "sha1") { + hash_algorithm = HASH_AlgSHA1; + hash_length = SHA1_LENGTH; + } else if (j->second == "sha256") { + hash_algorithm = HASH_AlgSHA256; + hash_length = SHA256_LENGTH; + } else { + continue; + } + + j = m.find("h"); + if (j == m.end()) + continue; + + std::vector<uint8> given_hash; + if (!base::HexStringToBytes(j->second, &given_hash)) + continue; + + if (given_hash.size() != hash_length) + continue; + + uint8 calculated_hash[SHA256_LENGTH]; // SHA256 is the largest. + SECStatus rv; + + j = m.find("hr"); + if (j == m.end() || j->second == "pubkey") { + rv = HASH_HashBuf(hash_algorithm, calculated_hash, + server_cert_nss->derPublicKey.data, + server_cert_nss->derPublicKey.len); + } else if (j->second == "cert") { + rv = HASH_HashBuf(hash_algorithm, calculated_hash, + server_cert_nss->derCert.data, + server_cert_nss->derCert.len); + } else { + continue; + } + + if (rv != SECSuccess) + NOTREACHED(); + + found_well_formed_record = true; + + if (memcmp(calculated_hash, &given_hash[0], hash_length) == 0) { + matched_record = true; + if (dnssec) + return DNSVR_SUCCESS; + } + } + + if (found_well_formed_record && !matched_record) + return DNSVR_FAILURE; + + return DNSVR_CONTINUE; +} + +// CheckDNSSECChain tries to validate a DNSSEC chain embedded in +// |server_cert_nss_|. It returns true iff a chain is found that proves the +// value of a TXT record that contains a valid public key fingerprint. +DNSValidationResult CheckDNSSECChain( + const std::string& hostname, + CERTCertificate* server_cert_nss) { + if (!server_cert_nss) + return DNSVR_CONTINUE; + + // CERT_FindCertExtensionByOID isn't exported so we have to install an OID, + // get a tag for it and find the extension by using that tag. + static SECOidTag dnssec_chain_tag; + static bool dnssec_chain_tag_valid; + if (!dnssec_chain_tag_valid) { + // It's harmless if multiple threads enter this block concurrently. + static const uint8 kDNSSECChainOID[] = + // 1.3.6.1.4.1.11129.2.1.4 + // (iso.org.dod.internet.private.enterprises.google.googleSecurity. + // certificateExtensions.dnssecEmbeddedChain) + {0x2b, 0x06, 0x01, 0x04, 0x01, 0xd6, 0x79, 0x02, 0x01, 0x04}; + SECOidData oid_data; + memset(&oid_data, 0, sizeof(oid_data)); + oid_data.oid.data = const_cast<uint8*>(kDNSSECChainOID); + oid_data.oid.len = sizeof(kDNSSECChainOID); + oid_data.desc = "DNSSEC chain"; + oid_data.supportedExtension = SUPPORTED_CERT_EXTENSION; + dnssec_chain_tag = SECOID_AddEntry(&oid_data); + DCHECK_NE(SEC_OID_UNKNOWN, dnssec_chain_tag); + dnssec_chain_tag_valid = true; + } + + SECItem dnssec_embedded_chain; + SECStatus rv = CERT_FindCertExtension(server_cert_nss, + dnssec_chain_tag, &dnssec_embedded_chain); + if (rv != SECSuccess) + return DNSVR_CONTINUE; + + base::StringPiece chain( + reinterpret_cast<char*>(dnssec_embedded_chain.data), + dnssec_embedded_chain.len); + std::string dns_hostname; + if (!DNSDomainFromDot(hostname, &dns_hostname)) + return DNSVR_CONTINUE; + DNSSECChainVerifier verifier(dns_hostname, chain); + DNSSECChainVerifier::Error err = verifier.Verify(); + if (err != DNSSECChainVerifier::OK) { + LOG(ERROR) << "DNSSEC chain verification failed: " << err; + return DNSVR_CONTINUE; + } + + if (verifier.rrtype() != kDNS_TXT) + return DNSVR_CONTINUE; + + DNSValidationResult r = VerifyTXTRecords( + true /* DNSSEC verified */, server_cert_nss, verifier.rrdatas()); + SECITEM_FreeItem(&dnssec_embedded_chain, PR_FALSE); + return r; +} + } // namespace SSLClientSocketNSS::SSLClientSocketNSS(ClientSocketHandle* transport_socket, @@ -333,156 +489,94 @@ SSLClientSocketNSS::~SSLClientSocketNSS() { LeaveFunction(""); } -int SSLClientSocketNSS::Init() { - EnterFunction(""); - // Initialize the NSS SSL library in a threadsafe way. This also - // initializes the NSS base library. - EnsureNSSSSLInit(); - if (!NSS_IsInitialized()) - return ERR_UNEXPECTED; -#if !defined(OS_MACOSX) && !defined(OS_WIN) - // We must call EnsureOCSPInit() here, on the IO thread, to get the IO loop - // by MessageLoopForIO::current(). - // X509Certificate::Verify() runs on a worker thread of CertVerifier. - EnsureOCSPInit(); -#endif - - LeaveFunction(""); - return OK; +// static +void SSLClientSocketNSS::ClearSessionCache() { + SSL_ClearSessionCache(); } -// SaveSnapStartInfo extracts the information needed to perform a Snap Start -// with this server in the future (if any) and tells |ssl_host_info_| to -// preserve it. -void SSLClientSocketNSS::SaveSnapStartInfo() { - if (!ssl_host_info_.get()) - return; - - // If the SSLHostInfo hasn't managed to load from disk yet then we can't save - // anything. - if (ssl_host_info_->WaitForDataReady(NULL) != OK) - return; +void SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { + EnterFunction(""); + ssl_info->Reset(); - SECStatus rv; - SSLSnapStartResult snap_start_type; - rv = SSL_GetSnapStartResult(nss_fd_, &snap_start_type); - if (rv != SECSuccess) { - NOTREACHED(); - return; - } - net_log_.AddEvent(NetLog::TYPE_SSL_SNAP_START, - new NetLogIntegerParameter("type", snap_start_type)); - if (snap_start_type == SSL_SNAP_START_FULL || - snap_start_type == SSL_SNAP_START_RESUME) { - // If we did a successful Snap Start then our information was correct and - // there's no point saving it again. + if (!server_cert_) { + LOG(DFATAL) << "!server_cert_"; return; } - const unsigned char* hello_data; - unsigned hello_data_len; - rv = SSL_GetPredictedServerHelloData(nss_fd_, &hello_data, &hello_data_len); - if (rv != SECSuccess) { - NOTREACHED(); - return; - } - if (hello_data_len > std::numeric_limits<uint16>::max()) - return; - SSLHostInfo::State* state = ssl_host_info_->mutable_state(); + ssl_info->cert_status = server_cert_verify_result_->cert_status; + DCHECK(server_cert_ != NULL); + ssl_info->cert = server_cert_; + ssl_info->connection_status = ssl_connection_status_; - if (hello_data_len > 0) { - state->server_hello = - std::string(reinterpret_cast<const char *>(hello_data), hello_data_len); - state->npn_valid = true; - state->npn_status = GetNextProto(&state->npn_protocol); + PRUint16 cipher_suite = + SSLConnectionStatusToCipherSuite(ssl_connection_status_); + SSLCipherSuiteInfo cipher_info; + SECStatus ok = SSL_GetCipherSuiteInfo(cipher_suite, + &cipher_info, sizeof(cipher_info)); + if (ok == SECSuccess) { + ssl_info->security_bits = cipher_info.effectiveKeyBits; } else { - state->server_hello.clear(); - state->npn_valid = false; - } - - state->certs.clear(); - PeerCertificateChain certs(nss_fd_); - for (unsigned i = 0; i < certs.size(); i++) { - if (certs[i]->derCert.len > std::numeric_limits<uint16>::max()) - return; - - state->certs.push_back(std::string( - reinterpret_cast<char*>(certs[i]->derCert.data), - certs[i]->derCert.len)); + ssl_info->security_bits = -1; + LOG(DFATAL) << "SSL_GetCipherSuiteInfo returned " << PR_GetError() + << " for cipherSuite " << cipher_suite; } - - ssl_host_info_->Persist(); + LeaveFunction(""); } -static void DestroyCertificates(CERTCertificate** certs, unsigned len) { - for (unsigned i = 0; i < len; i++) - CERT_DestroyCertificate(certs[i]); +void SSLClientSocketNSS::GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) { + EnterFunction(""); + // TODO(rch): switch SSLCertRequestInfo.host_and_port to a HostPortPair + cert_request_info->host_and_port = host_and_port_.ToString(); + cert_request_info->client_certs = client_certs_; + LeaveFunction(cert_request_info->client_certs.size()); } -// LoadSnapStartInfo parses |info|, which contains data previously serialised -// by |SaveSnapStartInfo|, and sets the predicted certificates and ServerHello -// data on the NSS socket. Returns true on success. If this function returns -// false, the caller should try a normal TLS handshake. -bool SSLClientSocketNSS::LoadSnapStartInfo() { - const SSLHostInfo::State& state(ssl_host_info_->state()); - - if (state.server_hello.empty() || - state.certs.empty() || - !state.npn_valid) { - return false; +SSLClientSocket::NextProtoStatus +SSLClientSocketNSS::GetNextProto(std::string* proto) { +#if defined(SSL_NEXT_PROTO_NEGOTIATED) + if (!handshake_callback_called_) { + DCHECK(pseudo_connected_); + predicted_npn_proto_used_ = true; + *proto = predicted_npn_proto_; + return predicted_npn_status_; } - SECStatus rv; - rv = SSL_SetPredictedServerHelloData( - nss_fd_, - reinterpret_cast<const uint8*>(state.server_hello.data()), - state.server_hello.size()); - DCHECK_EQ(SECSuccess, rv); - - const std::vector<std::string>& certs_in = state.certs; - scoped_array<CERTCertificate*> certs(new CERTCertificate*[certs_in.size()]); - for (size_t i = 0; i < certs_in.size(); i++) { - SECItem derCert; - derCert.data = - const_cast<uint8*>(reinterpret_cast<const uint8*>(certs_in[i].data())); - derCert.len = certs_in[i].size(); - certs[i] = CERT_NewTempCertificate( - CERT_GetDefaultCertDB(), &derCert, NULL /* no nickname given */, - PR_FALSE /* not permanent */, PR_TRUE /* copy DER data */); - if (!certs[i]) { - DestroyCertificates(&certs[0], i); - NOTREACHED(); - return false; - } + unsigned char buf[255]; + int state; + unsigned len; + SECStatus rv = SSL_GetNextProto(nss_fd_, &state, buf, &len, sizeof(buf)); + if (rv != SECSuccess) { + NOTREACHED() << "Error return from SSL_GetNextProto: " << rv; + proto->clear(); + return kNextProtoUnsupported; } - - rv = SSL_SetPredictedPeerCertificates(nss_fd_, certs.get(), certs_in.size()); - DestroyCertificates(&certs[0], certs_in.size()); - DCHECK_EQ(SECSuccess, rv); - - if (state.npn_valid) { - predicted_npn_status_ = state.npn_status; - predicted_npn_proto_ = state.npn_protocol; + // We don't check for truncation because sizeof(buf) is large enough to hold + // the maximum protocol size. + switch (state) { + case SSL_NEXT_PROTO_NO_SUPPORT: + proto->clear(); + return kNextProtoUnsupported; + case SSL_NEXT_PROTO_NEGOTIATED: + *proto = std::string(reinterpret_cast<char*>(buf), len); + return kNextProtoNegotiated; + case SSL_NEXT_PROTO_NO_OVERLAP: + *proto = std::string(reinterpret_cast<char*>(buf), len); + return kNextProtoNoOverlap; + default: + NOTREACHED() << "Unknown status from SSL_GetNextProto: " << state; + proto->clear(); + return kNextProtoUnsupported; } - - return true; -} - -bool SSLClientSocketNSS::IsNPNProtocolMispredicted() { - DCHECK(handshake_callback_called_); - if (!predicted_npn_proto_used_) - return false; - std::string npn_proto; - GetNextProto(&npn_proto); - return predicted_npn_proto_ != npn_proto; +#else + // No NPN support in the libssl that we are building with. + proto->clear(); + return kNextProtoUnsupported; +#endif } -void SSLClientSocketNSS::UncorkAfterTimeout() { - corked_ = false; - int nsent; - do { - nsent = BufferSend(); - } while (nsent > 0); +void SSLClientSocketNSS::UseDNSSEC(DNSSECProvider* provider) { + dnssec_provider_ = provider; } int SSLClientSocketNSS::Connect(CompletionCallback* callback) { @@ -542,6 +636,228 @@ int SSLClientSocketNSS::Connect(CompletionCallback* callback) { return rv > OK ? OK : rv; } +void SSLClientSocketNSS::Disconnect() { + EnterFunction(""); + + // TODO(wtc): Send SSL close_notify alert. + if (nss_fd_ != NULL) { + PR_Close(nss_fd_); + nss_fd_ = NULL; + } + + // Shut down anything that may call us back (through buffer_send_callback_, + // buffer_recv_callback, or handshake_io_callback_). + verifier_.reset(); + transport_->socket()->Disconnect(); + + // Reset object state + transport_send_busy_ = false; + transport_recv_busy_ = false; + user_connect_callback_ = NULL; + user_read_callback_ = NULL; + user_write_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + server_cert_ = NULL; + if (server_cert_nss_) { + CERT_DestroyCertificate(server_cert_nss_); + server_cert_nss_ = NULL; + } + local_server_cert_verify_result_.Reset(); + server_cert_verify_result_ = NULL; + ssl_connection_status_ = 0; + completed_handshake_ = false; + pseudo_connected_ = false; + eset_mitm_detected_ = false; + start_cert_verification_time_ = base::TimeTicks(); + predicted_cert_chain_correct_ = false; + peername_initialized_ = false; + nss_bufs_ = NULL; + client_certs_.clear(); + client_auth_cert_needed_ = false; + + LeaveFunction(""); +} + +bool SSLClientSocketNSS::IsConnected() const { + // Ideally, we should also check if we have received the close_notify alert + // message from the server, and return false in that case. We're not doing + // that, so this function may return a false positive. Since the upper + // layer (HttpNetworkTransaction) needs to handle a persistent connection + // closed by the server when we send a request anyway, a false positive in + // exchange for simpler code is a good trade-off. + EnterFunction(""); + bool ret = (pseudo_connected_ || completed_handshake_) && + transport_->socket()->IsConnected(); + LeaveFunction(""); + return ret; +} + +bool SSLClientSocketNSS::IsConnectedAndIdle() const { + // Unlike IsConnected, this method doesn't return a false positive. + // + // Strictly speaking, we should check if we have received the close_notify + // alert message from the server, and return false in that case. Although + // the close_notify alert message means EOF in the SSL layer, it is just + // bytes to the transport layer below, so + // transport_->socket()->IsConnectedAndIdle() returns the desired false + // when we receive close_notify. + EnterFunction(""); + bool ret = (pseudo_connected_ || completed_handshake_) && + transport_->socket()->IsConnectedAndIdle(); + LeaveFunction(""); + return ret; +} + +int SSLClientSocketNSS::GetPeerAddress(AddressList* address) const { + return transport_->socket()->GetPeerAddress(address); +} + +const BoundNetLog& SSLClientSocketNSS::NetLog() const { + return net_log_; +} + +void SSLClientSocketNSS::SetSubresourceSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetSubresourceSpeculation(); + } else { + NOTREACHED(); + } +} + +void SSLClientSocketNSS::SetOmniboxSpeculation() { + if (transport_.get() && transport_->socket()) { + transport_->socket()->SetOmniboxSpeculation(); + } else { + NOTREACHED(); + } +} + +bool SSLClientSocketNSS::WasEverUsed() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->WasEverUsed(); + } + NOTREACHED(); + return false; +} + +bool SSLClientSocketNSS::UsingTCPFastOpen() const { + if (transport_.get() && transport_->socket()) { + return transport_->socket()->UsingTCPFastOpen(); + } + NOTREACHED(); + return false; +} + +int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + EnterFunction(buf_len); + DCHECK(!user_read_callback_); + DCHECK(!user_connect_callback_); + DCHECK(!user_read_buf_); + DCHECK(nss_bufs_); + + user_read_buf_ = buf; + user_read_buf_len_ = buf_len; + + if (!completed_handshake_) { + // In this case we have lied about being connected in order to merge the + // first Write into a Snap Start handshake. We'll leave the read hanging + // until the handshake has completed. + DCHECK(pseudo_connected_); + + user_read_callback_ = callback; + LeaveFunction(ERR_IO_PENDING); + return ERR_IO_PENDING; + } + + int rv = DoReadLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_read_callback_ = callback; + } else { + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + } + LeaveFunction(rv); + return rv; +} + +int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, + CompletionCallback* callback) { + EnterFunction(buf_len); + if (!pseudo_connected_) { + DCHECK(completed_handshake_); + DCHECK(next_handshake_state_ == STATE_NONE); + DCHECK(!user_connect_callback_); + } + DCHECK(!user_write_callback_); + DCHECK(!user_write_buf_); + DCHECK(nss_bufs_); + + user_write_buf_ = buf; + user_write_buf_len_ = buf_len; + + if (next_handshake_state_ == STATE_SNAP_START_WAIT_FOR_WRITE) { + // We lied about being connected and we have been waiting for this write in + // order to merge it into the Snap Start handshake. We'll leave the write + // pending until the handshake completes. + DCHECK(pseudo_connected_); + int rv = DoHandshakeLoop(OK); + if (rv == ERR_IO_PENDING) { + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + if (rv != OK) + return rv; + } + + if (corked_) { + corked_ = false; + uncork_timer_.Reset(); + } + int rv = DoWriteLoop(OK); + + if (rv == ERR_IO_PENDING) { + user_write_callback_ = callback; + } else { + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + } + LeaveFunction(rv); + return rv; +} + +bool SSLClientSocketNSS::SetReceiveBufferSize(int32 size) { + return transport_->socket()->SetReceiveBufferSize(size); +} + +bool SSLClientSocketNSS::SetSendBufferSize(int32 size) { + return transport_->socket()->SetSendBufferSize(size); +} + +int SSLClientSocketNSS::Init() { + EnterFunction(""); + // Initialize the NSS SSL library in a threadsafe way. This also + // initializes the NSS base library. + EnsureNSSSSLInit(); + if (!NSS_IsInitialized()) + return ERR_UNEXPECTED; +#if !defined(OS_MACOSX) && !defined(OS_WIN) + // We must call EnsureOCSPInit() here, on the IO thread, to get the IO loop + // by MessageLoopForIO::current(). + // X509Certificate::Verify() runs on a worker thread of CertVerifier. + EnsureOCSPInit(); +#endif + + LeaveFunction(""); + return OK; +} + int SSLClientSocketNSS::InitializeSSLOptions() { // Transport connected, now hook it up to nss // TODO(port): specify rx and tx buffer sizes separately @@ -764,214 +1080,6 @@ int SSLClientSocketNSS::InitializeSSLPeerName() { return OK; } -void SSLClientSocketNSS::Disconnect() { - EnterFunction(""); - - // TODO(wtc): Send SSL close_notify alert. - if (nss_fd_ != NULL) { - PR_Close(nss_fd_); - nss_fd_ = NULL; - } - - // Shut down anything that may call us back (through buffer_send_callback_, - // buffer_recv_callback, or handshake_io_callback_). - verifier_.reset(); - transport_->socket()->Disconnect(); - - // Reset object state - transport_send_busy_ = false; - transport_recv_busy_ = false; - user_connect_callback_ = NULL; - user_read_callback_ = NULL; - user_write_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - server_cert_ = NULL; - if (server_cert_nss_) { - CERT_DestroyCertificate(server_cert_nss_); - server_cert_nss_ = NULL; - } - local_server_cert_verify_result_.Reset(); - server_cert_verify_result_ = NULL; - ssl_connection_status_ = 0; - completed_handshake_ = false; - pseudo_connected_ = false; - eset_mitm_detected_ = false; - start_cert_verification_time_ = base::TimeTicks(); - predicted_cert_chain_correct_ = false; - peername_initialized_ = false; - nss_bufs_ = NULL; - client_certs_.clear(); - client_auth_cert_needed_ = false; - - LeaveFunction(""); -} - -bool SSLClientSocketNSS::IsConnected() const { - // Ideally, we should also check if we have received the close_notify alert - // message from the server, and return false in that case. We're not doing - // that, so this function may return a false positive. Since the upper - // layer (HttpNetworkTransaction) needs to handle a persistent connection - // closed by the server when we send a request anyway, a false positive in - // exchange for simpler code is a good trade-off. - EnterFunction(""); - bool ret = (pseudo_connected_ || completed_handshake_) && - transport_->socket()->IsConnected(); - LeaveFunction(""); - return ret; -} - -bool SSLClientSocketNSS::IsConnectedAndIdle() const { - // Unlike IsConnected, this method doesn't return a false positive. - // - // Strictly speaking, we should check if we have received the close_notify - // alert message from the server, and return false in that case. Although - // the close_notify alert message means EOF in the SSL layer, it is just - // bytes to the transport layer below, so - // transport_->socket()->IsConnectedAndIdle() returns the desired false - // when we receive close_notify. - EnterFunction(""); - bool ret = (pseudo_connected_ || completed_handshake_) && - transport_->socket()->IsConnectedAndIdle(); - LeaveFunction(""); - return ret; -} - -int SSLClientSocketNSS::GetPeerAddress(AddressList* address) const { - return transport_->socket()->GetPeerAddress(address); -} - -const BoundNetLog& SSLClientSocketNSS::NetLog() const { - return net_log_; -} - -void SSLClientSocketNSS::SetSubresourceSpeculation() { - if (transport_.get() && transport_->socket()) { - transport_->socket()->SetSubresourceSpeculation(); - } else { - NOTREACHED(); - } -} - -void SSLClientSocketNSS::SetOmniboxSpeculation() { - if (transport_.get() && transport_->socket()) { - transport_->socket()->SetOmniboxSpeculation(); - } else { - NOTREACHED(); - } -} - -bool SSLClientSocketNSS::WasEverUsed() const { - if (transport_.get() && transport_->socket()) { - return transport_->socket()->WasEverUsed(); - } - NOTREACHED(); - return false; -} - -bool SSLClientSocketNSS::UsingTCPFastOpen() const { - if (transport_.get() && transport_->socket()) { - return transport_->socket()->UsingTCPFastOpen(); - } - NOTREACHED(); - return false; -} - -int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - EnterFunction(buf_len); - DCHECK(!user_read_callback_); - DCHECK(!user_connect_callback_); - DCHECK(!user_read_buf_); - DCHECK(nss_bufs_); - - user_read_buf_ = buf; - user_read_buf_len_ = buf_len; - - if (!completed_handshake_) { - // In this case we have lied about being connected in order to merge the - // first Write into a Snap Start handshake. We'll leave the read hanging - // until the handshake has completed. - DCHECK(pseudo_connected_); - - user_read_callback_ = callback; - LeaveFunction(ERR_IO_PENDING); - return ERR_IO_PENDING; - } - - int rv = DoReadLoop(OK); - - if (rv == ERR_IO_PENDING) { - user_read_callback_ = callback; - } else { - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - } - LeaveFunction(rv); - return rv; -} - -int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, - CompletionCallback* callback) { - EnterFunction(buf_len); - if (!pseudo_connected_) { - DCHECK(completed_handshake_); - DCHECK(next_handshake_state_ == STATE_NONE); - DCHECK(!user_connect_callback_); - } - DCHECK(!user_write_callback_); - DCHECK(!user_write_buf_); - DCHECK(nss_bufs_); - - user_write_buf_ = buf; - user_write_buf_len_ = buf_len; - - if (next_handshake_state_ == STATE_SNAP_START_WAIT_FOR_WRITE) { - // We lied about being connected and we have been waiting for this write in - // order to merge it into the Snap Start handshake. We'll leave the write - // pending until the handshake completes. - DCHECK(pseudo_connected_); - int rv = DoHandshakeLoop(OK); - if (rv == ERR_IO_PENDING) { - user_write_callback_ = callback; - } else { - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - } - if (rv != OK) - return rv; - } - - if (corked_) { - corked_ = false; - uncork_timer_.Reset(); - } - int rv = DoWriteLoop(OK); - - if (rv == ERR_IO_PENDING) { - user_write_callback_ = callback; - } else { - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - } - LeaveFunction(rv); - return rv; -} - -bool SSLClientSocketNSS::SetReceiveBufferSize(int32 size) { - return transport_->socket()->SetReceiveBufferSize(size); -} - -bool SSLClientSocketNSS::SetSendBufferSize(int32 size) { - return transport_->socket()->SetSendBufferSize(size); -} - -// static -void SSLClientSocketNSS::ClearSessionCache() { - SSL_ClearSessionCache(); -} // Sets server_cert_ and server_cert_nss_ if not yet set. // Returns server_cert_. @@ -1051,91 +1159,6 @@ void SSLClientSocketNSS::UpdateConnectionStatus() { ssl_connection_status_ |= SSL_CONNECTION_SSL3_FALLBACK; } -void SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { - EnterFunction(""); - ssl_info->Reset(); - - if (!server_cert_) { - LOG(DFATAL) << "!server_cert_"; - return; - } - - ssl_info->cert_status = server_cert_verify_result_->cert_status; - DCHECK(server_cert_ != NULL); - ssl_info->cert = server_cert_; - ssl_info->connection_status = ssl_connection_status_; - - PRUint16 cipher_suite = - SSLConnectionStatusToCipherSuite(ssl_connection_status_); - SSLCipherSuiteInfo cipher_info; - SECStatus ok = SSL_GetCipherSuiteInfo(cipher_suite, - &cipher_info, sizeof(cipher_info)); - if (ok == SECSuccess) { - ssl_info->security_bits = cipher_info.effectiveKeyBits; - } else { - ssl_info->security_bits = -1; - LOG(DFATAL) << "SSL_GetCipherSuiteInfo returned " << PR_GetError() - << " for cipherSuite " << cipher_suite; - } - LeaveFunction(""); -} - -void SSLClientSocketNSS::GetSSLCertRequestInfo( - SSLCertRequestInfo* cert_request_info) { - EnterFunction(""); - // TODO(rch): switch SSLCertRequestInfo.host_and_port to a HostPortPair - cert_request_info->host_and_port = host_and_port_.ToString(); - cert_request_info->client_certs = client_certs_; - LeaveFunction(cert_request_info->client_certs.size()); -} - -SSLClientSocket::NextProtoStatus -SSLClientSocketNSS::GetNextProto(std::string* proto) { -#if defined(SSL_NEXT_PROTO_NEGOTIATED) - if (!handshake_callback_called_) { - DCHECK(pseudo_connected_); - predicted_npn_proto_used_ = true; - *proto = predicted_npn_proto_; - return predicted_npn_status_; - } - - unsigned char buf[255]; - int state; - unsigned len; - SECStatus rv = SSL_GetNextProto(nss_fd_, &state, buf, &len, sizeof(buf)); - if (rv != SECSuccess) { - NOTREACHED() << "Error return from SSL_GetNextProto: " << rv; - proto->clear(); - return kNextProtoUnsupported; - } - // We don't check for truncation because sizeof(buf) is large enough to hold - // the maximum protocol size. - switch (state) { - case SSL_NEXT_PROTO_NO_SUPPORT: - proto->clear(); - return kNextProtoUnsupported; - case SSL_NEXT_PROTO_NEGOTIATED: - *proto = std::string(reinterpret_cast<char*>(buf), len); - return kNextProtoNegotiated; - case SSL_NEXT_PROTO_NO_OVERLAP: - *proto = std::string(reinterpret_cast<char*>(buf), len); - return kNextProtoNoOverlap; - default: - NOTREACHED() << "Unknown status from SSL_GetNextProto: " << state; - proto->clear(); - return kNextProtoUnsupported; - } -#else - // No NPN support in the libssl that we are building with. - proto->clear(); - return kNextProtoUnsupported; -#endif -} - -void SSLClientSocketNSS::UseDNSSEC(DNSSECProvider* provider) { - dnssec_provider_ = provider; -} - void SSLClientSocketNSS::DoReadCallback(int rv) { EnterFunction(rv); DCHECK(rv != ERR_IO_PENDING); @@ -1250,109 +1273,6 @@ void SSLClientSocketNSS::OnRecvComplete(int result) { LeaveFunction(""); } -// Do network I/O between the given buffer and the given socket. -// Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) -bool SSLClientSocketNSS::DoTransportIO() { - EnterFunction(""); - bool network_moved = false; - if (nss_bufs_ != NULL) { - int nsent = BufferSend(); - int nreceived = BufferRecv(); - network_moved = (nsent > 0 || nreceived >= 0); - } - LeaveFunction(network_moved); - return network_moved; -} - -// Return 0 for EOF, -// > 0 for bytes transferred immediately, -// < 0 for error (or the non-error ERR_IO_PENDING). -int SSLClientSocketNSS::BufferSend(void) { - if (transport_send_busy_) - return ERR_IO_PENDING; - - EnterFunction(""); - const char* buf1; - const char* buf2; - unsigned int len1, len2; - memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2); - const unsigned int len = len1 + len2; - - if (corked_ && len < kRecvBufferSize / 2) - return 0; - - int rv = 0; - if (len) { - scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len)); - memcpy(send_buffer->data(), buf1, len1); - memcpy(send_buffer->data() + len1, buf2, len2); - rv = transport_->socket()->Write(send_buffer, len, - &buffer_send_callback_); - if (rv == ERR_IO_PENDING) { - transport_send_busy_ = true; - } else { - memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv)); - } - } - - LeaveFunction(rv); - return rv; -} - -void SSLClientSocketNSS::BufferSendComplete(int result) { - EnterFunction(result); - - // In the case of TCP FastOpen, connect is now finished. - if (!peername_initialized_ && UsingTCPFastOpen()) - InitializeSSLPeerName(); - - memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result)); - transport_send_busy_ = false; - OnSendComplete(result); - LeaveFunction(""); -} - - -int SSLClientSocketNSS::BufferRecv(void) { - if (transport_recv_busy_) return ERR_IO_PENDING; - - char *buf; - int nb = memio_GetReadParams(nss_bufs_, &buf); - EnterFunction(nb); - int rv; - if (!nb) { - // buffer too full to read into, so no I/O possible at moment - rv = ERR_IO_PENDING; - } else { - recv_buffer_ = new IOBuffer(nb); - rv = transport_->socket()->Read(recv_buffer_, nb, &buffer_recv_callback_); - if (rv == ERR_IO_PENDING) { - transport_recv_busy_ = true; - } else { - if (rv > 0) - memcpy(buf, recv_buffer_->data(), rv); - memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv)); - recv_buffer_ = NULL; - } - } - LeaveFunction(rv); - return rv; -} - -void SSLClientSocketNSS::BufferRecvComplete(int result) { - EnterFunction(result); - if (result > 0) { - char *buf; - memio_GetReadParams(nss_bufs_, &buf); - memcpy(buf, recv_buffer_->data(), result); - } - recv_buffer_ = NULL; - memio_PutReadResult(nss_bufs_, MapErrorToNSS(result)); - transport_recv_busy_ = false; - OnRecvComplete(result); - LeaveFunction(""); -} - int SSLClientSocketNSS::DoHandshakeLoop(int last_io_result) { EnterFunction(last_io_result); bool network_moved; @@ -1459,447 +1379,6 @@ int SSLClientSocketNSS::DoWriteLoop(int result) { return rv; } -// static -// NSS calls this if an incoming certificate needs to be verified. -// Do nothing but return SECSuccess. -// This is called only in full handshake mode. -// Peer certificate is retrieved in HandshakeCallback() later, which is called -// in full handshake mode or in resumption handshake mode. -SECStatus SSLClientSocketNSS::OwnAuthCertHandler(void* arg, - PRFileDesc* socket, - PRBool checksig, - PRBool is_server) { -#ifdef SSL_ENABLE_FALSE_START - // In the event that we are False Starting this connection, we wish to send - // out the Finished message and first application data record in the same - // packet. This prevents non-determinism when talking to False Start - // intolerant servers which, otherwise, might see the two messages in - // different reads or not, depending on network conditions. - PRBool false_start = 0; - SECStatus rv = SSL_OptionGet(socket, SSL_ENABLE_FALSE_START, &false_start); - DCHECK_EQ(SECSuccess, rv); - - if (false_start) { - SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); - - // ESET anti-virus is capable of intercepting HTTPS connections on Windows. - // However, it is False Start intolerant and causes the connections to hang - // forever. We detect ESET by the issuer of the leaf certificate and set a - // flag to return a specific error, giving the user instructions for - // reconfiguring ESET. - CERTCertificate* cert = SSL_PeerCertificate(that->nss_fd_); - if (cert) { - char* common_name = CERT_GetCommonName(&cert->issuer); - if (common_name) { - if (strcmp(common_name, "ESET_RootSslCert") == 0) - that->eset_mitm_detected_ = true; - if (strcmp(common_name, - "ContentWatch Root Certificate Authority") == 0) { - // This is NetNanny. NetNanny are updating their product so we - // silently disable False Start for now. - rv = SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE); - DCHECK_EQ(SECSuccess, rv); - false_start = 0; - } - PORT_Free(common_name); - } - CERT_DestroyCertificate(cert); - } - - if (false_start && !that->handshake_callback_called_) { - that->corked_ = true; - that->uncork_timer_.Start( - base::TimeDelta::FromMilliseconds(kCorkTimeoutMs), - that, &SSLClientSocketNSS::UncorkAfterTimeout); - } - } -#endif - - // Tell NSS to not verify the certificate. - return SECSuccess; -} - -#if defined(NSS_PLATFORM_CLIENT_AUTH) -// static -// NSS calls this if a client certificate is needed. -SECStatus SSLClientSocketNSS::PlatformClientAuthHandler( - void* arg, - PRFileDesc* socket, - CERTDistNames* ca_names, - CERTCertList** result_certs, - void** result_private_key) { - SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); - - that->client_auth_cert_needed_ = !that->ssl_config_.send_client_cert; -#if defined(OS_WIN) - if (that->ssl_config_.send_client_cert) { - if (that->ssl_config_.client_cert) { - PCCERT_CONTEXT cert_context = - that->ssl_config_.client_cert->os_cert_handle(); - if (VLOG_IS_ON(1)) { - do { - DWORD size_needed = 0; - BOOL got_info = CertGetCertificateContextProperty( - cert_context, CERT_KEY_PROV_INFO_PROP_ID, NULL, &size_needed); - if (!got_info) { - VLOG(1) << "Failed to get key prov info size " << GetLastError(); - break; - } - std::vector<BYTE> raw_info(size_needed); - got_info = CertGetCertificateContextProperty( - cert_context, CERT_KEY_PROV_INFO_PROP_ID, &raw_info[0], - &size_needed); - if (!got_info) { - VLOG(1) << "Failed to get key prov info " << GetLastError(); - break; - } - PCRYPT_KEY_PROV_INFO info = - reinterpret_cast<PCRYPT_KEY_PROV_INFO>(&raw_info[0]); - VLOG(1) << "Container Name: " << info->pwszContainerName - << "\nProvider Name: " << info->pwszProvName - << "\nProvider Type: " << info->dwProvType - << "\nFlags: " << info->dwFlags - << "\nProvider Param Count: " << info->cProvParam - << "\nKey Specifier: " << info->dwKeySpec; - } while (false); - - do { - DWORD size_needed = 0; - BOOL got_identifier = CertGetCertificateContextProperty( - cert_context, CERT_KEY_IDENTIFIER_PROP_ID, NULL, &size_needed); - if (!got_identifier) { - VLOG(1) << "Failed to get key identifier size " - << GetLastError(); - break; - } - std::vector<BYTE> raw_id(size_needed); - got_identifier = CertGetCertificateContextProperty( - cert_context, CERT_KEY_IDENTIFIER_PROP_ID, &raw_id[0], - &size_needed); - if (!got_identifier) { - VLOG(1) << "Failed to get key identifier " << GetLastError(); - break; - } - VLOG(1) << "Key Identifier: " << base::HexEncode(&raw_id[0], - size_needed); - } while (false); - } - HCRYPTPROV provider = NULL; - DWORD key_spec = AT_KEYEXCHANGE; - BOOL must_free = FALSE; - BOOL acquired_key = CryptAcquireCertificatePrivateKey( - cert_context, - CRYPT_ACQUIRE_CACHE_FLAG | CRYPT_ACQUIRE_COMPARE_KEY_FLAG, - NULL, &provider, &key_spec, &must_free); - if (acquired_key && provider) { - DCHECK_NE(key_spec, CERT_NCRYPT_KEY_SPEC); - - // The certificate cache may have been updated/used, in which case, - // duplicate the existing handle, since NSS will free it when no - // longer in use. - if (!must_free) - CryptContextAddRef(provider, NULL, 0); - - SECItem der_cert; - der_cert.type = siDERCertBuffer; - der_cert.data = cert_context->pbCertEncoded; - der_cert.len = cert_context->cbCertEncoded; - - // TODO(rsleevi): Error checking for NSS allocation errors. - *result_certs = CERT_NewCertList(); - CERTCertDBHandle* db_handle = CERT_GetDefaultCertDB(); - CERTCertificate* user_cert = CERT_NewTempCertificate( - db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); - CERT_AddCertToListTail(*result_certs, user_cert); - - // Add the intermediates. - X509Certificate::OSCertHandles intermediates = - that->ssl_config_.client_cert->GetIntermediateCertificates(); - for (X509Certificate::OSCertHandles::const_iterator it = - intermediates.begin(); it != intermediates.end(); ++it) { - der_cert.data = (*it)->pbCertEncoded; - der_cert.len = (*it)->cbCertEncoded; - - CERTCertificate* intermediate = CERT_NewTempCertificate( - db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); - CERT_AddCertToListTail(*result_certs, intermediate); - } - // TODO(wtc): |key_spec| should be passed along with |provider|. - *result_private_key = reinterpret_cast<void*>(provider); - return SECSuccess; - } - LOG(WARNING) << "Client cert found without private key"; - } - // Send no client certificate. - return SECFailure; - } - - that->client_certs_.clear(); - - std::vector<CERT_NAME_BLOB> issuer_list(ca_names->nnames); - for (int i = 0; i < ca_names->nnames; ++i) { - issuer_list[i].cbData = ca_names->names[i].len; - issuer_list[i].pbData = ca_names->names[i].data; - } - - // Client certificates of the user are in the "MY" system certificate store. - HCERTSTORE my_cert_store = CertOpenSystemStore(NULL, L"MY"); - if (!my_cert_store) { - LOG(ERROR) << "Could not open the \"MY\" system certificate store: " - << GetLastError(); - return SECFailure; - } - - // Enumerate the client certificates. - CERT_CHAIN_FIND_BY_ISSUER_PARA find_by_issuer_para; - memset(&find_by_issuer_para, 0, sizeof(find_by_issuer_para)); - find_by_issuer_para.cbSize = sizeof(find_by_issuer_para); - find_by_issuer_para.pszUsageIdentifier = szOID_PKIX_KP_CLIENT_AUTH; - find_by_issuer_para.cIssuer = ca_names->nnames; - find_by_issuer_para.rgIssuer = ca_names->nnames ? &issuer_list[0] : NULL; - find_by_issuer_para.pfnFindCallback = ClientCertFindCallback; - - PCCERT_CHAIN_CONTEXT chain_context = NULL; - - for (;;) { - // Find a certificate chain. - chain_context = CertFindChainInStore(my_cert_store, - X509_ASN_ENCODING, - 0, - CERT_CHAIN_FIND_BY_ISSUER, - &find_by_issuer_para, - chain_context); - if (!chain_context) { - DWORD err = GetLastError(); - if (err != CRYPT_E_NOT_FOUND) - DLOG(ERROR) << "CertFindChainInStore failed: " << err; - break; - } - - // Get the leaf certificate. - PCCERT_CONTEXT cert_context = - chain_context->rgpChain[0]->rgpElement[0]->pCertContext; - // Copy it to our own certificate store, so that we can close the "MY" - // certificate store before returning from this function. - PCCERT_CONTEXT cert_context2; - BOOL ok = CertAddCertificateContextToStore(X509Certificate::cert_store(), - cert_context, - CERT_STORE_ADD_USE_EXISTING, - &cert_context2); - if (!ok) { - NOTREACHED(); - continue; - } - - // Copy the rest of the chain to our own store as well. Copying the chain - // stops gracefully if an error is encountered, with the partial chain - // being used as the intermediates, rather than failing to consider the - // client certificate. - net::X509Certificate::OSCertHandles intermediates; - for (DWORD i = 1; i < chain_context->rgpChain[0]->cElement; i++) { - PCCERT_CONTEXT intermediate_copy; - ok = CertAddCertificateContextToStore(X509Certificate::cert_store(), - chain_context->rgpChain[0]->rgpElement[i]->pCertContext, - CERT_STORE_ADD_USE_EXISTING, &intermediate_copy); - if (!ok) { - NOTREACHED(); - break; - } - intermediates.push_back(intermediate_copy); - } - - scoped_refptr<X509Certificate> cert = X509Certificate::CreateFromHandle( - cert_context2, X509Certificate::SOURCE_LONE_CERT_IMPORT, - intermediates); - that->client_certs_.push_back(cert); - - X509Certificate::FreeOSCertHandle(cert_context2); - for (net::X509Certificate::OSCertHandles::iterator it = - intermediates.begin(); it != intermediates.end(); ++it) { - net::X509Certificate::FreeOSCertHandle(*it); - } - } - - BOOL ok = CertCloseStore(my_cert_store, CERT_CLOSE_STORE_CHECK_FLAG); - DCHECK(ok); - - // Tell NSS to suspend the client authentication. We will then abort the - // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. - return SECWouldBlock; -#elif defined(OS_MACOSX) - if (that->ssl_config_.send_client_cert) { - if (that->ssl_config_.client_cert) { - OSStatus os_error = noErr; - SecIdentityRef identity = NULL; - SecKeyRef private_key = NULL; - CFArrayRef chain = - that->ssl_config_.client_cert->CreateClientCertificateChain(); - if (chain) { - identity = reinterpret_cast<SecIdentityRef>( - const_cast<void*>(CFArrayGetValueAtIndex(chain, 0))); - } - if (identity) - os_error = SecIdentityCopyPrivateKey(identity, &private_key); - - if (chain && identity && os_error == noErr) { - // TODO(rsleevi): Error checking for NSS allocation errors. - *result_certs = CERT_NewCertList(); - *result_private_key = reinterpret_cast<void*>(private_key); - - for (CFIndex i = 0; i < CFArrayGetCount(chain); ++i) { - CSSM_DATA cert_data; - SecCertificateRef cert_ref; - if (i == 0) { - cert_ref = that->ssl_config_.client_cert->os_cert_handle(); - } else { - cert_ref = reinterpret_cast<SecCertificateRef>( - const_cast<void*>(CFArrayGetValueAtIndex(chain, i))); - } - os_error = SecCertificateGetData(cert_ref, &cert_data); - if (os_error != noErr) - break; - - SECItem der_cert; - der_cert.type = siDERCertBuffer; - der_cert.data = cert_data.Data; - der_cert.len = cert_data.Length; - CERTCertificate* nss_cert = CERT_NewTempCertificate( - CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE); - CERT_AddCertToListTail(*result_certs, nss_cert); - } - } - if (os_error == noErr) { - CFRelease(chain); - return SECSuccess; - } - LOG(WARNING) << "Client cert found, but could not be used: " - << os_error; - if (*result_certs) { - CERT_DestroyCertList(*result_certs); - *result_certs = NULL; - } - if (*result_private_key) - *result_private_key = NULL; - if (private_key) - CFRelease(private_key); - if (chain) - CFRelease(chain); - } - // Send no client certificate. - return SECFailure; - } - - that->client_certs_.clear(); - - // First, get the cert issuer names allowed by the server. - std::vector<CertPrincipal> valid_issuers; - int n = ca_names->nnames; - for (int i = 0; i < n; i++) { - // Parse each name into a CertPrincipal object. - CertPrincipal p; - if (p.ParseDistinguishedName(ca_names->names[i].data, - ca_names->names[i].len)) { - valid_issuers.push_back(p); - } - } - - // Now get the available client certs whose issuers are allowed by the server. - X509Certificate::GetSSLClientCertificates(that->host_and_port_.host(), - valid_issuers, - &that->client_certs_); - - // Tell NSS to suspend the client authentication. We will then abort the - // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. - return SECWouldBlock; -#else - return SECFailure; -#endif -} - -#else // NSS_PLATFORM_CLIENT_AUTH - -// static -// NSS calls this if a client certificate is needed. -// Based on Mozilla's NSS_GetClientAuthData. -SECStatus SSLClientSocketNSS::ClientAuthHandler( - void* arg, - PRFileDesc* socket, - CERTDistNames* ca_names, - CERTCertificate** result_certificate, - SECKEYPrivateKey** result_private_key) { - SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); - - that->client_auth_cert_needed_ = !that->ssl_config_.send_client_cert; - void* wincx = SSL_RevealPinArg(socket); - - // Second pass: a client certificate should have been selected. - if (that->ssl_config_.send_client_cert) { - if (that->ssl_config_.client_cert) { - CERTCertificate* cert = CERT_DupCertificate( - that->ssl_config_.client_cert->os_cert_handle()); - SECKEYPrivateKey* privkey = PK11_FindKeyByAnyCert(cert, wincx); - if (privkey) { - // TODO(jsorianopastor): We should wait for server certificate - // verification before sending our credentials. See - // http://crbug.com/13934. - *result_certificate = cert; - *result_private_key = privkey; - return SECSuccess; - } - LOG(WARNING) << "Client cert found without private key"; - } - // Send no client certificate. - return SECFailure; - } - - // Iterate over all client certificates. - CERTCertList* client_certs = CERT_FindUserCertsByUsage( - CERT_GetDefaultCertDB(), certUsageSSLClient, - PR_FALSE, PR_FALSE, wincx); - if (client_certs) { - for (CERTCertListNode* node = CERT_LIST_HEAD(client_certs); - !CERT_LIST_END(node, client_certs); - node = CERT_LIST_NEXT(node)) { - // Only offer unexpired certificates. - if (CERT_CheckCertValidTimes(node->cert, PR_Now(), PR_TRUE) != - secCertTimeValid) - continue; - // Filter by issuer. - // - // TODO(davidben): This does a binary comparison of the DER-encoded - // issuers. We should match according to RFC 5280 sec. 7.1. We should find - // an appropriate NSS function or add one if needbe. - if (ca_names->nnames && - NSS_CmpCertChainWCANames(node->cert, ca_names) != SECSuccess) - continue; - X509Certificate* x509_cert = X509Certificate::CreateFromHandle( - node->cert, X509Certificate::SOURCE_LONE_CERT_IMPORT, - net::X509Certificate::OSCertHandles()); - that->client_certs_.push_back(x509_cert); - } - CERT_DestroyCertList(client_certs); - } - - // Tell NSS to suspend the client authentication. We will then abort the - // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. - return SECWouldBlock; -} -#endif // NSS_PLATFORM_CLIENT_AUTH - -// static -// NSS calls this when handshake is completed. -// After the SSL handshake is finished, use CertVerifier to verify -// the saved server certificate. -void SSLClientSocketNSS::HandshakeCallback(PRFileDesc* socket, - void* arg) { - SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); - - that->handshake_callback_called_ = true; - - that->UpdateServerCert(); - that->UpdateConnectionStatus(); -} - int SSLClientSocketNSS::DoSnapStartLoadInfo() { EnterFunction(""); int rv = ssl_host_info_->WaitForDataReady(&handshake_io_callback_); @@ -2083,158 +1562,6 @@ int SSLClientSocketNSS::DoHandshake() { return net_error; } -// DNSValidationResult enumerates the possible outcomes from processing a -// set of DNS records. -enum DNSValidationResult { - DNSVR_SUCCESS, // the cert is immediately acceptable. - DNSVR_FAILURE, // the cert is unconditionally rejected. - DNSVR_CONTINUE, // perform CA validation as usual. -}; - -// VerifyTXTRecords processes the RRDATA for a number of DNS TXT records and -// checks them against the given certificate. -// dnssec: if true then the TXT records are DNSSEC validated. In this case, -// DNSVR_SUCCESS may be returned. -// server_cert_nss: the certificate to validate -// rrdatas: the TXT records for the current domain. -static DNSValidationResult VerifyTXTRecords( - bool dnssec, - CERTCertificate* server_cert_nss, - const std::vector<base::StringPiece>& rrdatas) { - bool found_well_formed_record = false; - bool matched_record = false; - - for (std::vector<base::StringPiece>::const_iterator - i = rrdatas.begin(); i != rrdatas.end(); ++i) { - std::map<std::string, std::string> m( - DNSSECChainVerifier::ParseTLSTXTRecord(*i)); - if (m.empty()) - continue; - - std::map<std::string, std::string>::const_iterator j; - j = m.find("v"); - if (j == m.end() || j->second != "tls1") - continue; - - j = m.find("ha"); - - HASH_HashType hash_algorithm; - unsigned hash_length; - if (j == m.end() || j->second == "sha1") { - hash_algorithm = HASH_AlgSHA1; - hash_length = SHA1_LENGTH; - } else if (j->second == "sha256") { - hash_algorithm = HASH_AlgSHA256; - hash_length = SHA256_LENGTH; - } else { - continue; - } - - j = m.find("h"); - if (j == m.end()) - continue; - - std::vector<uint8> given_hash; - if (!base::HexStringToBytes(j->second, &given_hash)) - continue; - - if (given_hash.size() != hash_length) - continue; - - uint8 calculated_hash[SHA256_LENGTH]; // SHA256 is the largest. - SECStatus rv; - - j = m.find("hr"); - if (j == m.end() || j->second == "pubkey") { - rv = HASH_HashBuf(hash_algorithm, calculated_hash, - server_cert_nss->derPublicKey.data, - server_cert_nss->derPublicKey.len); - } else if (j->second == "cert") { - rv = HASH_HashBuf(hash_algorithm, calculated_hash, - server_cert_nss->derCert.data, - server_cert_nss->derCert.len); - } else { - continue; - } - - if (rv != SECSuccess) - NOTREACHED(); - - found_well_formed_record = true; - - if (memcmp(calculated_hash, &given_hash[0], hash_length) == 0) { - matched_record = true; - if (dnssec) - return DNSVR_SUCCESS; - } - } - - if (found_well_formed_record && !matched_record) - return DNSVR_FAILURE; - - return DNSVR_CONTINUE; -} - - -// CheckDNSSECChain tries to validate a DNSSEC chain embedded in -// |server_cert_nss_|. It returns true iff a chain is found that proves the -// value of a TXT record that contains a valid public key fingerprint. -static DNSValidationResult CheckDNSSECChain( - const std::string& hostname, - CERTCertificate* server_cert_nss) { - if (!server_cert_nss) - return DNSVR_CONTINUE; - - // CERT_FindCertExtensionByOID isn't exported so we have to install an OID, - // get a tag for it and find the extension by using that tag. - static SECOidTag dnssec_chain_tag; - static bool dnssec_chain_tag_valid; - if (!dnssec_chain_tag_valid) { - // It's harmless if multiple threads enter this block concurrently. - static const uint8 kDNSSECChainOID[] = - // 1.3.6.1.4.1.11129.2.1.4 - // (iso.org.dod.internet.private.enterprises.google.googleSecurity. - // certificateExtensions.dnssecEmbeddedChain) - {0x2b, 0x06, 0x01, 0x04, 0x01, 0xd6, 0x79, 0x02, 0x01, 0x04}; - SECOidData oid_data; - memset(&oid_data, 0, sizeof(oid_data)); - oid_data.oid.data = const_cast<uint8*>(kDNSSECChainOID); - oid_data.oid.len = sizeof(kDNSSECChainOID); - oid_data.desc = "DNSSEC chain"; - oid_data.supportedExtension = SUPPORTED_CERT_EXTENSION; - dnssec_chain_tag = SECOID_AddEntry(&oid_data); - DCHECK_NE(SEC_OID_UNKNOWN, dnssec_chain_tag); - dnssec_chain_tag_valid = true; - } - - SECItem dnssec_embedded_chain; - SECStatus rv = CERT_FindCertExtension(server_cert_nss, - dnssec_chain_tag, &dnssec_embedded_chain); - if (rv != SECSuccess) - return DNSVR_CONTINUE; - - base::StringPiece chain( - reinterpret_cast<char*>(dnssec_embedded_chain.data), - dnssec_embedded_chain.len); - std::string dns_hostname; - if (!DNSDomainFromDot(hostname, &dns_hostname)) - return DNSVR_CONTINUE; - DNSSECChainVerifier verifier(dns_hostname, chain); - DNSSECChainVerifier::Error err = verifier.Verify(); - if (err != DNSSECChainVerifier::OK) { - LOG(ERROR) << "DNSSEC chain verification failed: " << err; - return DNSVR_CONTINUE; - } - - if (verifier.rrtype() != kDNS_TXT) - return DNSVR_CONTINUE; - - DNSValidationResult r = VerifyTXTRecords( - true /* DNSSEC verified */, server_cert_nss, verifier.rrdatas()); - SECITEM_FreeItem(&dnssec_embedded_chain, PR_FALSE); - return r; -} - int SSLClientSocketNSS::DoVerifyDNSSEC(int result) { if (ssl_config_.dns_cert_provenance_checking_enabled && dns_cert_checker_) { @@ -2515,4 +1842,676 @@ void SSLClientSocketNSS::LogConnectionTypeMetrics() const { }; } +// SaveSnapStartInfo extracts the information needed to perform a Snap Start +// with this server in the future (if any) and tells |ssl_host_info_| to +// preserve it. +void SSLClientSocketNSS::SaveSnapStartInfo() { + if (!ssl_host_info_.get()) + return; + + // If the SSLHostInfo hasn't managed to load from disk yet then we can't save + // anything. + if (ssl_host_info_->WaitForDataReady(NULL) != OK) + return; + + SECStatus rv; + SSLSnapStartResult snap_start_type; + rv = SSL_GetSnapStartResult(nss_fd_, &snap_start_type); + if (rv != SECSuccess) { + NOTREACHED(); + return; + } + net_log_.AddEvent(NetLog::TYPE_SSL_SNAP_START, + new NetLogIntegerParameter("type", snap_start_type)); + if (snap_start_type == SSL_SNAP_START_FULL || + snap_start_type == SSL_SNAP_START_RESUME) { + // If we did a successful Snap Start then our information was correct and + // there's no point saving it again. + return; + } + + const unsigned char* hello_data; + unsigned hello_data_len; + rv = SSL_GetPredictedServerHelloData(nss_fd_, &hello_data, &hello_data_len); + if (rv != SECSuccess) { + NOTREACHED(); + return; + } + if (hello_data_len > std::numeric_limits<uint16>::max()) + return; + SSLHostInfo::State* state = ssl_host_info_->mutable_state(); + + if (hello_data_len > 0) { + state->server_hello = + std::string(reinterpret_cast<const char *>(hello_data), hello_data_len); + state->npn_valid = true; + state->npn_status = GetNextProto(&state->npn_protocol); + } else { + state->server_hello.clear(); + state->npn_valid = false; + } + + state->certs.clear(); + PeerCertificateChain certs(nss_fd_); + for (unsigned i = 0; i < certs.size(); i++) { + if (certs[i]->derCert.len > std::numeric_limits<uint16>::max()) + return; + + state->certs.push_back(std::string( + reinterpret_cast<char*>(certs[i]->derCert.data), + certs[i]->derCert.len)); + } + + ssl_host_info_->Persist(); +} + +// LoadSnapStartInfo parses |info|, which contains data previously serialised +// by |SaveSnapStartInfo|, and sets the predicted certificates and ServerHello +// data on the NSS socket. Returns true on success. If this function returns +// false, the caller should try a normal TLS handshake. +bool SSLClientSocketNSS::LoadSnapStartInfo() { + const SSLHostInfo::State& state(ssl_host_info_->state()); + + if (state.server_hello.empty() || + state.certs.empty() || + !state.npn_valid) { + return false; + } + + SECStatus rv; + rv = SSL_SetPredictedServerHelloData( + nss_fd_, + reinterpret_cast<const uint8*>(state.server_hello.data()), + state.server_hello.size()); + DCHECK_EQ(SECSuccess, rv); + + const std::vector<std::string>& certs_in = state.certs; + scoped_array<CERTCertificate*> certs(new CERTCertificate*[certs_in.size()]); + for (size_t i = 0; i < certs_in.size(); i++) { + SECItem derCert; + derCert.data = + const_cast<uint8*>(reinterpret_cast<const uint8*>(certs_in[i].data())); + derCert.len = certs_in[i].size(); + certs[i] = CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &derCert, NULL /* no nickname given */, + PR_FALSE /* not permanent */, PR_TRUE /* copy DER data */); + if (!certs[i]) { + DestroyCertificates(&certs[0], i); + NOTREACHED(); + return false; + } + } + + rv = SSL_SetPredictedPeerCertificates(nss_fd_, certs.get(), certs_in.size()); + DestroyCertificates(&certs[0], certs_in.size()); + DCHECK_EQ(SECSuccess, rv); + + if (state.npn_valid) { + predicted_npn_status_ = state.npn_status; + predicted_npn_proto_ = state.npn_protocol; + } + + return true; +} + +bool SSLClientSocketNSS::IsNPNProtocolMispredicted() { + DCHECK(handshake_callback_called_); + if (!predicted_npn_proto_used_) + return false; + std::string npn_proto; + GetNextProto(&npn_proto); + return predicted_npn_proto_ != npn_proto; +} + +void SSLClientSocketNSS::UncorkAfterTimeout() { + corked_ = false; + int nsent; + do { + nsent = BufferSend(); + } while (nsent > 0); +} + +// Do network I/O between the given buffer and the given socket. +// Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) +bool SSLClientSocketNSS::DoTransportIO() { + EnterFunction(""); + bool network_moved = false; + if (nss_bufs_ != NULL) { + int nsent = BufferSend(); + int nreceived = BufferRecv(); + network_moved = (nsent > 0 || nreceived >= 0); + } + LeaveFunction(network_moved); + return network_moved; +} + +// Return 0 for EOF, +// > 0 for bytes transferred immediately, +// < 0 for error (or the non-error ERR_IO_PENDING). +int SSLClientSocketNSS::BufferSend(void) { + if (transport_send_busy_) + return ERR_IO_PENDING; + + EnterFunction(""); + const char* buf1; + const char* buf2; + unsigned int len1, len2; + memio_GetWriteParams(nss_bufs_, &buf1, &len1, &buf2, &len2); + const unsigned int len = len1 + len2; + + if (corked_ && len < kRecvBufferSize / 2) + return 0; + + int rv = 0; + if (len) { + scoped_refptr<IOBuffer> send_buffer(new IOBuffer(len)); + memcpy(send_buffer->data(), buf1, len1); + memcpy(send_buffer->data() + len1, buf2, len2); + rv = transport_->socket()->Write(send_buffer, len, + &buffer_send_callback_); + if (rv == ERR_IO_PENDING) { + transport_send_busy_ = true; + } else { + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(rv)); + } + } + + LeaveFunction(rv); + return rv; +} + +void SSLClientSocketNSS::BufferSendComplete(int result) { + EnterFunction(result); + + // In the case of TCP FastOpen, connect is now finished. + if (!peername_initialized_ && UsingTCPFastOpen()) + InitializeSSLPeerName(); + + memio_PutWriteResult(nss_bufs_, MapErrorToNSS(result)); + transport_send_busy_ = false; + OnSendComplete(result); + LeaveFunction(""); +} + +int SSLClientSocketNSS::BufferRecv(void) { + if (transport_recv_busy_) return ERR_IO_PENDING; + + char *buf; + int nb = memio_GetReadParams(nss_bufs_, &buf); + EnterFunction(nb); + int rv; + if (!nb) { + // buffer too full to read into, so no I/O possible at moment + rv = ERR_IO_PENDING; + } else { + recv_buffer_ = new IOBuffer(nb); + rv = transport_->socket()->Read(recv_buffer_, nb, &buffer_recv_callback_); + if (rv == ERR_IO_PENDING) { + transport_recv_busy_ = true; + } else { + if (rv > 0) + memcpy(buf, recv_buffer_->data(), rv); + memio_PutReadResult(nss_bufs_, MapErrorToNSS(rv)); + recv_buffer_ = NULL; + } + } + LeaveFunction(rv); + return rv; +} + +void SSLClientSocketNSS::BufferRecvComplete(int result) { + EnterFunction(result); + if (result > 0) { + char *buf; + memio_GetReadParams(nss_bufs_, &buf); + memcpy(buf, recv_buffer_->data(), result); + } + recv_buffer_ = NULL; + memio_PutReadResult(nss_bufs_, MapErrorToNSS(result)); + transport_recv_busy_ = false; + OnRecvComplete(result); + LeaveFunction(""); +} + +// static +// NSS calls this if an incoming certificate needs to be verified. +// Do nothing but return SECSuccess. +// This is called only in full handshake mode. +// Peer certificate is retrieved in HandshakeCallback() later, which is called +// in full handshake mode or in resumption handshake mode. +SECStatus SSLClientSocketNSS::OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server) { +#ifdef SSL_ENABLE_FALSE_START + // In the event that we are False Starting this connection, we wish to send + // out the Finished message and first application data record in the same + // packet. This prevents non-determinism when talking to False Start + // intolerant servers which, otherwise, might see the two messages in + // different reads or not, depending on network conditions. + PRBool false_start = 0; + SECStatus rv = SSL_OptionGet(socket, SSL_ENABLE_FALSE_START, &false_start); + DCHECK_EQ(SECSuccess, rv); + + if (false_start) { + SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); + + // ESET anti-virus is capable of intercepting HTTPS connections on Windows. + // However, it is False Start intolerant and causes the connections to hang + // forever. We detect ESET by the issuer of the leaf certificate and set a + // flag to return a specific error, giving the user instructions for + // reconfiguring ESET. + CERTCertificate* cert = SSL_PeerCertificate(that->nss_fd_); + if (cert) { + char* common_name = CERT_GetCommonName(&cert->issuer); + if (common_name) { + if (strcmp(common_name, "ESET_RootSslCert") == 0) + that->eset_mitm_detected_ = true; + if (strcmp(common_name, + "ContentWatch Root Certificate Authority") == 0) { + // This is NetNanny. NetNanny are updating their product so we + // silently disable False Start for now. + rv = SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE); + DCHECK_EQ(SECSuccess, rv); + false_start = 0; + } + PORT_Free(common_name); + } + CERT_DestroyCertificate(cert); + } + + if (false_start && !that->handshake_callback_called_) { + that->corked_ = true; + that->uncork_timer_.Start( + base::TimeDelta::FromMilliseconds(kCorkTimeoutMs), + that, &SSLClientSocketNSS::UncorkAfterTimeout); + } + } +#endif + + // Tell NSS to not verify the certificate. + return SECSuccess; +} + +#if defined(NSS_PLATFORM_CLIENT_AUTH) +// static +// NSS calls this if a client certificate is needed. +SECStatus SSLClientSocketNSS::PlatformClientAuthHandler( + void* arg, + PRFileDesc* socket, + CERTDistNames* ca_names, + CERTCertList** result_certs, + void** result_private_key) { + SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); + + that->client_auth_cert_needed_ = !that->ssl_config_.send_client_cert; +#if defined(OS_WIN) + if (that->ssl_config_.send_client_cert) { + if (that->ssl_config_.client_cert) { + PCCERT_CONTEXT cert_context = + that->ssl_config_.client_cert->os_cert_handle(); + if (VLOG_IS_ON(1)) { + do { + DWORD size_needed = 0; + BOOL got_info = CertGetCertificateContextProperty( + cert_context, CERT_KEY_PROV_INFO_PROP_ID, NULL, &size_needed); + if (!got_info) { + VLOG(1) << "Failed to get key prov info size " << GetLastError(); + break; + } + std::vector<BYTE> raw_info(size_needed); + got_info = CertGetCertificateContextProperty( + cert_context, CERT_KEY_PROV_INFO_PROP_ID, &raw_info[0], + &size_needed); + if (!got_info) { + VLOG(1) << "Failed to get key prov info " << GetLastError(); + break; + } + PCRYPT_KEY_PROV_INFO info = + reinterpret_cast<PCRYPT_KEY_PROV_INFO>(&raw_info[0]); + VLOG(1) << "Container Name: " << info->pwszContainerName + << "\nProvider Name: " << info->pwszProvName + << "\nProvider Type: " << info->dwProvType + << "\nFlags: " << info->dwFlags + << "\nProvider Param Count: " << info->cProvParam + << "\nKey Specifier: " << info->dwKeySpec; + } while (false); + + do { + DWORD size_needed = 0; + BOOL got_identifier = CertGetCertificateContextProperty( + cert_context, CERT_KEY_IDENTIFIER_PROP_ID, NULL, &size_needed); + if (!got_identifier) { + VLOG(1) << "Failed to get key identifier size " + << GetLastError(); + break; + } + std::vector<BYTE> raw_id(size_needed); + got_identifier = CertGetCertificateContextProperty( + cert_context, CERT_KEY_IDENTIFIER_PROP_ID, &raw_id[0], + &size_needed); + if (!got_identifier) { + VLOG(1) << "Failed to get key identifier " << GetLastError(); + break; + } + VLOG(1) << "Key Identifier: " << base::HexEncode(&raw_id[0], + size_needed); + } while (false); + } + HCRYPTPROV provider = NULL; + DWORD key_spec = AT_KEYEXCHANGE; + BOOL must_free = FALSE; + BOOL acquired_key = CryptAcquireCertificatePrivateKey( + cert_context, + CRYPT_ACQUIRE_CACHE_FLAG | CRYPT_ACQUIRE_COMPARE_KEY_FLAG, + NULL, &provider, &key_spec, &must_free); + if (acquired_key && provider) { + DCHECK_NE(key_spec, CERT_NCRYPT_KEY_SPEC); + + // The certificate cache may have been updated/used, in which case, + // duplicate the existing handle, since NSS will free it when no + // longer in use. + if (!must_free) + CryptContextAddRef(provider, NULL, 0); + + SECItem der_cert; + der_cert.type = siDERCertBuffer; + der_cert.data = cert_context->pbCertEncoded; + der_cert.len = cert_context->cbCertEncoded; + + // TODO(rsleevi): Error checking for NSS allocation errors. + *result_certs = CERT_NewCertList(); + CERTCertDBHandle* db_handle = CERT_GetDefaultCertDB(); + CERTCertificate* user_cert = CERT_NewTempCertificate( + db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); + CERT_AddCertToListTail(*result_certs, user_cert); + + // Add the intermediates. + X509Certificate::OSCertHandles intermediates = + that->ssl_config_.client_cert->GetIntermediateCertificates(); + for (X509Certificate::OSCertHandles::const_iterator it = + intermediates.begin(); it != intermediates.end(); ++it) { + der_cert.data = (*it)->pbCertEncoded; + der_cert.len = (*it)->cbCertEncoded; + + CERTCertificate* intermediate = CERT_NewTempCertificate( + db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); + CERT_AddCertToListTail(*result_certs, intermediate); + } + // TODO(wtc): |key_spec| should be passed along with |provider|. + *result_private_key = reinterpret_cast<void*>(provider); + return SECSuccess; + } + LOG(WARNING) << "Client cert found without private key"; + } + // Send no client certificate. + return SECFailure; + } + + that->client_certs_.clear(); + + std::vector<CERT_NAME_BLOB> issuer_list(ca_names->nnames); + for (int i = 0; i < ca_names->nnames; ++i) { + issuer_list[i].cbData = ca_names->names[i].len; + issuer_list[i].pbData = ca_names->names[i].data; + } + + // Client certificates of the user are in the "MY" system certificate store. + HCERTSTORE my_cert_store = CertOpenSystemStore(NULL, L"MY"); + if (!my_cert_store) { + LOG(ERROR) << "Could not open the \"MY\" system certificate store: " + << GetLastError(); + return SECFailure; + } + + // Enumerate the client certificates. + CERT_CHAIN_FIND_BY_ISSUER_PARA find_by_issuer_para; + memset(&find_by_issuer_para, 0, sizeof(find_by_issuer_para)); + find_by_issuer_para.cbSize = sizeof(find_by_issuer_para); + find_by_issuer_para.pszUsageIdentifier = szOID_PKIX_KP_CLIENT_AUTH; + find_by_issuer_para.cIssuer = ca_names->nnames; + find_by_issuer_para.rgIssuer = ca_names->nnames ? &issuer_list[0] : NULL; + find_by_issuer_para.pfnFindCallback = ClientCertFindCallback; + + PCCERT_CHAIN_CONTEXT chain_context = NULL; + + for (;;) { + // Find a certificate chain. + chain_context = CertFindChainInStore(my_cert_store, + X509_ASN_ENCODING, + 0, + CERT_CHAIN_FIND_BY_ISSUER, + &find_by_issuer_para, + chain_context); + if (!chain_context) { + DWORD err = GetLastError(); + if (err != CRYPT_E_NOT_FOUND) + DLOG(ERROR) << "CertFindChainInStore failed: " << err; + break; + } + + // Get the leaf certificate. + PCCERT_CONTEXT cert_context = + chain_context->rgpChain[0]->rgpElement[0]->pCertContext; + // Copy it to our own certificate store, so that we can close the "MY" + // certificate store before returning from this function. + PCCERT_CONTEXT cert_context2; + BOOL ok = CertAddCertificateContextToStore(X509Certificate::cert_store(), + cert_context, + CERT_STORE_ADD_USE_EXISTING, + &cert_context2); + if (!ok) { + NOTREACHED(); + continue; + } + + // Copy the rest of the chain to our own store as well. Copying the chain + // stops gracefully if an error is encountered, with the partial chain + // being used as the intermediates, rather than failing to consider the + // client certificate. + net::X509Certificate::OSCertHandles intermediates; + for (DWORD i = 1; i < chain_context->rgpChain[0]->cElement; i++) { + PCCERT_CONTEXT intermediate_copy; + ok = CertAddCertificateContextToStore(X509Certificate::cert_store(), + chain_context->rgpChain[0]->rgpElement[i]->pCertContext, + CERT_STORE_ADD_USE_EXISTING, &intermediate_copy); + if (!ok) { + NOTREACHED(); + break; + } + intermediates.push_back(intermediate_copy); + } + + scoped_refptr<X509Certificate> cert = X509Certificate::CreateFromHandle( + cert_context2, X509Certificate::SOURCE_LONE_CERT_IMPORT, + intermediates); + that->client_certs_.push_back(cert); + + X509Certificate::FreeOSCertHandle(cert_context2); + for (net::X509Certificate::OSCertHandles::iterator it = + intermediates.begin(); it != intermediates.end(); ++it) { + net::X509Certificate::FreeOSCertHandle(*it); + } + } + + BOOL ok = CertCloseStore(my_cert_store, CERT_CLOSE_STORE_CHECK_FLAG); + DCHECK(ok); + + // Tell NSS to suspend the client authentication. We will then abort the + // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. + return SECWouldBlock; +#elif defined(OS_MACOSX) + if (that->ssl_config_.send_client_cert) { + if (that->ssl_config_.client_cert) { + OSStatus os_error = noErr; + SecIdentityRef identity = NULL; + SecKeyRef private_key = NULL; + CFArrayRef chain = + that->ssl_config_.client_cert->CreateClientCertificateChain(); + if (chain) { + identity = reinterpret_cast<SecIdentityRef>( + const_cast<void*>(CFArrayGetValueAtIndex(chain, 0))); + } + if (identity) + os_error = SecIdentityCopyPrivateKey(identity, &private_key); + + if (chain && identity && os_error == noErr) { + // TODO(rsleevi): Error checking for NSS allocation errors. + *result_certs = CERT_NewCertList(); + *result_private_key = reinterpret_cast<void*>(private_key); + + for (CFIndex i = 0; i < CFArrayGetCount(chain); ++i) { + CSSM_DATA cert_data; + SecCertificateRef cert_ref; + if (i == 0) { + cert_ref = that->ssl_config_.client_cert->os_cert_handle(); + } else { + cert_ref = reinterpret_cast<SecCertificateRef>( + const_cast<void*>(CFArrayGetValueAtIndex(chain, i))); + } + os_error = SecCertificateGetData(cert_ref, &cert_data); + if (os_error != noErr) + break; + + SECItem der_cert; + der_cert.type = siDERCertBuffer; + der_cert.data = cert_data.Data; + der_cert.len = cert_data.Length; + CERTCertificate* nss_cert = CERT_NewTempCertificate( + CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE); + CERT_AddCertToListTail(*result_certs, nss_cert); + } + } + if (os_error == noErr) { + CFRelease(chain); + return SECSuccess; + } + LOG(WARNING) << "Client cert found, but could not be used: " + << os_error; + if (*result_certs) { + CERT_DestroyCertList(*result_certs); + *result_certs = NULL; + } + if (*result_private_key) + *result_private_key = NULL; + if (private_key) + CFRelease(private_key); + if (chain) + CFRelease(chain); + } + // Send no client certificate. + return SECFailure; + } + + that->client_certs_.clear(); + + // First, get the cert issuer names allowed by the server. + std::vector<CertPrincipal> valid_issuers; + int n = ca_names->nnames; + for (int i = 0; i < n; i++) { + // Parse each name into a CertPrincipal object. + CertPrincipal p; + if (p.ParseDistinguishedName(ca_names->names[i].data, + ca_names->names[i].len)) { + valid_issuers.push_back(p); + } + } + + // Now get the available client certs whose issuers are allowed by the server. + X509Certificate::GetSSLClientCertificates(that->host_and_port_.host(), + valid_issuers, + &that->client_certs_); + + // Tell NSS to suspend the client authentication. We will then abort the + // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. + return SECWouldBlock; +#else + return SECFailure; +#endif +} + +#else // NSS_PLATFORM_CLIENT_AUTH + +// static +// NSS calls this if a client certificate is needed. +// Based on Mozilla's NSS_GetClientAuthData. +SECStatus SSLClientSocketNSS::ClientAuthHandler( + void* arg, + PRFileDesc* socket, + CERTDistNames* ca_names, + CERTCertificate** result_certificate, + SECKEYPrivateKey** result_private_key) { + SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); + + that->client_auth_cert_needed_ = !that->ssl_config_.send_client_cert; + void* wincx = SSL_RevealPinArg(socket); + + // Second pass: a client certificate should have been selected. + if (that->ssl_config_.send_client_cert) { + if (that->ssl_config_.client_cert) { + CERTCertificate* cert = CERT_DupCertificate( + that->ssl_config_.client_cert->os_cert_handle()); + SECKEYPrivateKey* privkey = PK11_FindKeyByAnyCert(cert, wincx); + if (privkey) { + // TODO(jsorianopastor): We should wait for server certificate + // verification before sending our credentials. See + // http://crbug.com/13934. + *result_certificate = cert; + *result_private_key = privkey; + return SECSuccess; + } + LOG(WARNING) << "Client cert found without private key"; + } + // Send no client certificate. + return SECFailure; + } + + // Iterate over all client certificates. + CERTCertList* client_certs = CERT_FindUserCertsByUsage( + CERT_GetDefaultCertDB(), certUsageSSLClient, + PR_FALSE, PR_FALSE, wincx); + if (client_certs) { + for (CERTCertListNode* node = CERT_LIST_HEAD(client_certs); + !CERT_LIST_END(node, client_certs); + node = CERT_LIST_NEXT(node)) { + // Only offer unexpired certificates. + if (CERT_CheckCertValidTimes(node->cert, PR_Now(), PR_TRUE) != + secCertTimeValid) + continue; + // Filter by issuer. + // + // TODO(davidben): This does a binary comparison of the DER-encoded + // issuers. We should match according to RFC 5280 sec. 7.1. We should find + // an appropriate NSS function or add one if needbe. + if (ca_names->nnames && + NSS_CmpCertChainWCANames(node->cert, ca_names) != SECSuccess) + continue; + X509Certificate* x509_cert = X509Certificate::CreateFromHandle( + node->cert, X509Certificate::SOURCE_LONE_CERT_IMPORT, + net::X509Certificate::OSCertHandles()); + that->client_certs_.push_back(x509_cert); + } + CERT_DestroyCertList(client_certs); + } + + // Tell NSS to suspend the client authentication. We will then abort the + // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. + return SECWouldBlock; +} +#endif // NSS_PLATFORM_CLIENT_AUTH + +// static +// NSS calls this when handshake is completed. +// After the SSL handshake is finished, use CertVerifier to verify +// the saved server certificate. +void SSLClientSocketNSS::HandshakeCallback(PRFileDesc* socket, + void* arg) { + SSLClientSocketNSS* that = reinterpret_cast<SSLClientSocketNSS*>(arg); + + that->handshake_callback_called_ = true; + + that->UpdateServerCert(); + that->UpdateConnectionStatus(); +} + } // namespace net diff --git a/net/socket/ssl_client_socket_nss.h b/net/socket/ssl_client_socket_nss.h index bca4166..f0e089c 100644 --- a/net/socket/ssl_client_socket_nss.h +++ b/net/socket/ssl_client_socket_nss.h @@ -53,6 +53,9 @@ class SSLClientSocketNSS : public SSLClientSocket { DnsCertProvenanceChecker* dnsrr_resolver); ~SSLClientSocketNSS(); + // For tests + static void ClearSessionCache(); + // SSLClientSocket methods: virtual void GetSSLInfo(SSLInfo* ssl_info); virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info); @@ -77,10 +80,20 @@ class SSLClientSocketNSS : public SSLClientSocket { virtual bool SetReceiveBufferSize(int32 size); virtual bool SetSendBufferSize(int32 size); - // For tests - static void ClearSessionCache(); - private: + enum State { + STATE_NONE, + STATE_SNAP_START_LOAD_INFO, + STATE_SNAP_START_WAIT_FOR_WRITE, + STATE_HANDSHAKE, + STATE_VERIFY_DNSSEC, + STATE_VERIFY_DNSSEC_COMPLETE, + STATE_VERIFY_CERT, + STATE_VERIFY_CERT_COMPLETE, + }; + + int Init(); + // Initializes NSS SSL options. Returns a net error code. int InitializeSSLOptions(); @@ -115,7 +128,6 @@ class SSLClientSocketNSS : public SSLClientSocket { int DoPayloadRead(); int DoPayloadWrite(); void LogConnectionTypeMetrics() const; - int Init(); void SaveSnapStartInfo(); bool LoadSnapStartInfo(); bool IsNPNProtocolMispredicted(); @@ -123,8 +135,8 @@ class SSLClientSocketNSS : public SSLClientSocket { bool DoTransportIO(); int BufferSend(void); - int BufferRecv(void); void BufferSendComplete(int result); + int BufferRecv(void); void BufferRecvComplete(int result); // NSS calls this when checking certificates. We pass 'this' as the first @@ -224,16 +236,6 @@ class SSLClientSocketNSS : public SSLClientSocket { // The time when we started waiting for DNSSEC records. base::Time dnssec_wait_start_time_; - enum State { - STATE_NONE, - STATE_SNAP_START_LOAD_INFO, - STATE_SNAP_START_WAIT_FOR_WRITE, - STATE_HANDSHAKE, - STATE_VERIFY_DNSSEC, - STATE_VERIFY_DNSSEC_COMPLETE, - STATE_VERIFY_CERT, - STATE_VERIFY_CERT_COMPLETE, - }; State next_handshake_state_; // The NSS SSL state machine diff --git a/net/socket/ssl_server_socket_nss.cc b/net/socket/ssl_server_socket_nss.cc index 283ba50..61284fb 100644 --- a/net/socket/ssl_server_socket_nss.cc +++ b/net/socket/ssl_server_socket_nss.cc @@ -90,22 +90,6 @@ SSLServerSocketNSS::~SSLServerSocketNSS() { } } -int SSLServerSocketNSS::Init() { - // Initialize the NSS SSL library in a threadsafe way. This also - // initializes the NSS base library. - EnsureNSSSSLInit(); - if (!NSS_IsInitialized()) - return ERR_UNEXPECTED; -#if !defined(OS_MACOSX) && !defined(OS_WIN) - // We must call EnsureOCSPInit() here, on the IO thread, to get the IO loop - // by MessageLoopForIO::current(). - // X509Certificate::Verify() runs on a worker thread of CertVerifier. - EnsureOCSPInit(); -#endif - - return OK; -} - int SSLServerSocketNSS::Accept(CompletionCallback* callback) { net_log_.BeginEvent(NetLog::TYPE_SSL_ACCEPT, NULL); @@ -183,27 +167,12 @@ int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, return rv; } -// static -// NSS calls this if an incoming certificate needs to be verified. -// Do nothing but return SECSuccess. -// This is called only in full handshake mode. -// Peer certificate is retrieved in HandshakeCallback() later, which is called -// in full handshake mode or in resumption handshake mode. -SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg, - PRFileDesc* socket, - PRBool checksig, - PRBool is_server) { - // TODO(hclam): Implement. - // Tell NSS to not verify the certificate. - return SECSuccess; +bool SSLServerSocketNSS::SetReceiveBufferSize(int32 size) { + return false; } -// static -// NSS calls this when handshake is completed. -// After the SSL handshake is finished we need to verify the certificate. -void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket, - void* arg) { - // TODO(hclam): Implement. +bool SSLServerSocketNSS::SetSendBufferSize(int32 size) { + return false; } int SSLServerSocketNSS::InitializeSSLOptions() { @@ -381,6 +350,47 @@ int SSLServerSocketNSS::InitializeSSLOptions() { return OK; } +void SSLServerSocketNSS::OnSendComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + if (!user_write_buf_ || !completed_handshake_) + return; + + int rv = DoWriteLoop(result); + if (rv != ERR_IO_PENDING) + DoWriteCallback(rv); +} + +void SSLServerSocketNSS::OnRecvComplete(int result) { + if (next_handshake_state_ == STATE_HANDSHAKE) { + // In handshake phase. + OnHandshakeIOComplete(result); + return; + } + + // Network layer received some data, check if client requested to read + // decrypted data. + if (!user_read_buf_ || !completed_handshake_) + return; + + int rv = DoReadLoop(result); + if (rv != ERR_IO_PENDING) + DoReadCallback(rv); +} + +void SSLServerSocketNSS::OnHandshakeIOComplete(int result) { + int rv = DoHandshakeLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(net::NetLog::TYPE_SSL_ACCEPT, NULL); + if (user_accept_callback_) + DoAcceptCallback(rv); + } +} + // Return 0 for EOF, // > 0 for bytes transferred immediately, // < 0 for error (or the non-error ERR_IO_PENDING). @@ -453,81 +463,6 @@ void SSLServerSocketNSS::BufferRecvComplete(int result) { OnRecvComplete(result); } -void SSLServerSocketNSS::OnSendComplete(int result) { - if (next_handshake_state_ == STATE_HANDSHAKE) { - // In handshake phase. - OnHandshakeIOComplete(result); - return; - } - - if (!user_write_buf_ || !completed_handshake_) - return; - - int rv = DoWriteLoop(result); - if (rv != ERR_IO_PENDING) - DoWriteCallback(rv); -} - -void SSLServerSocketNSS::OnRecvComplete(int result) { - if (next_handshake_state_ == STATE_HANDSHAKE) { - // In handshake phase. - OnHandshakeIOComplete(result); - return; - } - - // Network layer received some data, check if client requested to read - // decrypted data. - if (!user_read_buf_ || !completed_handshake_) - return; - - int rv = DoReadLoop(result); - if (rv != ERR_IO_PENDING) - DoReadCallback(rv); -} - -void SSLServerSocketNSS::OnHandshakeIOComplete(int result) { - int rv = DoHandshakeLoop(result); - if (rv != ERR_IO_PENDING) { - net_log_.EndEvent(net::NetLog::TYPE_SSL_ACCEPT, NULL); - if (user_accept_callback_) - DoAcceptCallback(rv); - } -} - -void SSLServerSocketNSS::DoAcceptCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); - - CompletionCallback* c = user_accept_callback_; - user_accept_callback_ = NULL; - c->Run(rv > OK ? OK : rv); -} - -void SSLServerSocketNSS::DoReadCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_read_callback_); - - // Since Run may result in Read being called, clear |user_read_callback_| - // up front. - CompletionCallback* c = user_read_callback_; - user_read_callback_ = NULL; - user_read_buf_ = NULL; - user_read_buf_len_ = 0; - c->Run(rv); -} - -void SSLServerSocketNSS::DoWriteCallback(int rv) { - DCHECK(rv != ERR_IO_PENDING); - DCHECK(user_write_callback_); - - // Since Run may result in Write being called, clear |user_write_callback_| - // up front. - CompletionCallback* c = user_write_callback_; - user_write_callback_ = NULL; - user_write_buf_ = NULL; - user_write_buf_len_ = 0; - c->Run(rv); -} - // Do network I/O between the given buffer and the given socket. // Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) bool SSLServerSocketNSS::DoTransportIO() { @@ -674,4 +609,77 @@ int SSLServerSocketNSS::DoHandshake() { return net_error; } +void SSLServerSocketNSS::DoAcceptCallback(int rv) { + DCHECK_NE(rv, ERR_IO_PENDING); + + CompletionCallback* c = user_accept_callback_; + user_accept_callback_ = NULL; + c->Run(rv > OK ? OK : rv); +} + +void SSLServerSocketNSS::DoReadCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_read_callback_); + + // Since Run may result in Read being called, clear |user_read_callback_| + // up front. + CompletionCallback* c = user_read_callback_; + user_read_callback_ = NULL; + user_read_buf_ = NULL; + user_read_buf_len_ = 0; + c->Run(rv); +} + +void SSLServerSocketNSS::DoWriteCallback(int rv) { + DCHECK(rv != ERR_IO_PENDING); + DCHECK(user_write_callback_); + + // Since Run may result in Write being called, clear |user_write_callback_| + // up front. + CompletionCallback* c = user_write_callback_; + user_write_callback_ = NULL; + user_write_buf_ = NULL; + user_write_buf_len_ = 0; + c->Run(rv); +} + +// static +// NSS calls this if an incoming certificate needs to be verified. +// Do nothing but return SECSuccess. +// This is called only in full handshake mode. +// Peer certificate is retrieved in HandshakeCallback() later, which is called +// in full handshake mode or in resumption handshake mode. +SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server) { + // TODO(hclam): Implement. + // Tell NSS to not verify the certificate. + return SECSuccess; +} + +// static +// NSS calls this when handshake is completed. +// After the SSL handshake is finished we need to verify the certificate. +void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket, + void* arg) { + // TODO(hclam): Implement. +} + +int SSLServerSocketNSS::Init() { + // Initialize the NSS SSL library in a threadsafe way. This also + // initializes the NSS base library. + EnsureNSSSSLInit(); + if (!NSS_IsInitialized()) + return ERR_UNEXPECTED; +#if !defined(OS_MACOSX) && !defined(OS_WIN) + // We must call EnsureOCSPInit() here, on the IO thread, to get the IO loop + // by MessageLoopForIO::current(). + // X509Certificate::Verify() runs on a worker thread of CertVerifier. + EnsureOCSPInit(); +#endif + + return OK; +} + } // namespace net diff --git a/net/socket/ssl_server_socket_nss.h b/net/socket/ssl_server_socket_nss.h index 3883c9b..1289272 100644 --- a/net/socket/ssl_server_socket_nss.h +++ b/net/socket/ssl_server_socket_nss.h @@ -42,11 +42,14 @@ class SSLServerSocketNSS : public SSLServerSocket { CompletionCallback* callback); virtual int Write(IOBuffer* buf, int buf_len, CompletionCallback* callback); - virtual bool SetReceiveBufferSize(int32 size) { return false; } - virtual bool SetSendBufferSize(int32 size) { return false; } + virtual bool SetReceiveBufferSize(int32 size); + virtual bool SetSendBufferSize(int32 size); private: - virtual int Init(); + enum State { + STATE_NONE, + STATE_HANDSHAKE, + }; int InitializeSSLOptions(); @@ -59,8 +62,8 @@ class SSLServerSocketNSS : public SSLServerSocket { int BufferRecv(); void BufferRecvComplete(int result); bool DoTransportIO(); - int DoPayloadWrite(); int DoPayloadRead(); + int DoPayloadWrite(); int DoHandshakeLoop(int last_io_result); int DoReadLoop(int result); @@ -76,6 +79,8 @@ class SSLServerSocketNSS : public SSLServerSocket { PRBool is_server); static void HandshakeCallback(PRFileDesc* socket, void* arg); + virtual int Init(); + // Members used to send and receive buffer. CompletionCallbackImpl<SSLServerSocketNSS> buffer_send_callback_; CompletionCallbackImpl<SSLServerSocketNSS> buffer_recv_callback_; @@ -118,10 +123,6 @@ class SSLServerSocketNSS : public SSLServerSocket { // Private key used by the server. scoped_ptr<base::RSAPrivateKey> key_; - enum State { - STATE_NONE, - STATE_HANDSHAKE, - }; State next_handshake_state_; bool completed_handshake_; diff --git a/net/test/test_server.cc b/net/test/test_server.cc index 14da7f4..36ebf334 100644 --- a/net/test/test_server.cc +++ b/net/test/test_server.cc @@ -105,25 +105,6 @@ TestServer::~TestServer() { Stop(); } -void TestServer::Init(const FilePath& document_root) { - // At this point, the port that the testserver will listen on is unknown. - // The testserver will listen on an ephemeral port, and write the port - // number out over a pipe that this TestServer object will read from. Once - // that is complete, the host_port_pair_ will contain the actual port. - host_port_pair_ = HostPortPair(GetHostname(type_, https_options_), 0); - process_handle_ = base::kNullProcessHandle; - - FilePath src_dir; - PathService::Get(base::DIR_SOURCE_ROOT, &src_dir); - - document_root_ = src_dir.Append(document_root); - - certificates_dir_ = src_dir.Append(FILE_PATH_LITERAL("net")) - .Append(FILE_PATH_LITERAL("data")) - .Append(FILE_PATH_LITERAL("ssl")) - .Append(FILE_PATH_LITERAL("certificates")); -} - bool TestServer::Start() { if (type_ == TYPE_HTTPS) { if (!LoadTestRootCert()) @@ -276,6 +257,25 @@ bool TestServer::GetFilePathWithReplacements( return true; } +void TestServer::Init(const FilePath& document_root) { + // At this point, the port that the testserver will listen on is unknown. + // The testserver will listen on an ephemeral port, and write the port + // number out over a pipe that this TestServer object will read from. Once + // that is complete, the host_port_pair_ will contain the actual port. + host_port_pair_ = HostPortPair(GetHostname(type_, https_options_), 0); + process_handle_ = base::kNullProcessHandle; + + FilePath src_dir; + PathService::Get(base::DIR_SOURCE_ROOT, &src_dir); + + document_root_ = src_dir.Append(document_root); + + certificates_dir_ = src_dir.Append(FILE_PATH_LITERAL("net")) + .Append(FILE_PATH_LITERAL("data")) + .Append(FILE_PATH_LITERAL("ssl")) + .Append(FILE_PATH_LITERAL("certificates")); +} + bool TestServer::SetPythonPath() { FilePath third_party_dir; if (!PathService::Get(base::DIR_SOURCE_ROOT, &third_party_dir)) { @@ -307,6 +307,30 @@ bool TestServer::SetPythonPath() { return true; } +bool TestServer::ParseServerData(const std::string& server_data) { + VLOG(1) << "Server data: " << server_data; + base::JSONReader json_reader; + scoped_ptr<Value> value(json_reader.JsonToValue(server_data, true, false)); + if (!value.get() || + !value->IsType(Value::TYPE_DICTIONARY)) { + LOG(ERROR) << "Could not parse server data: " + << json_reader.GetErrorMessage(); + return false; + } + server_data_.reset(static_cast<DictionaryValue*>(value.release())); + int port = 0; + if (!server_data_->GetInteger("port", &port)) { + LOG(ERROR) << "Could not find port value"; + return false; + } + if ((port <= 0) || (port > kuint16max)) { + LOG(ERROR) << "Invalid port value: " << port; + return false; + } + host_port_pair_.set_port(port); + return true; +} + FilePath TestServer::GetRootCertificatePath() { return certificates_dir_.AppendASCII("root_ca_cert.crt"); } @@ -365,28 +389,4 @@ bool TestServer::AddCommandLineArguments(CommandLine* command_line) const { return true; } -bool TestServer::ParseServerData(const std::string& server_data) { - VLOG(1) << "Server data: " << server_data; - base::JSONReader json_reader; - scoped_ptr<Value> value(json_reader.JsonToValue(server_data, true, false)); - if (!value.get() || - !value->IsType(Value::TYPE_DICTIONARY)) { - LOG(ERROR) << "Could not parse server data: " - << json_reader.GetErrorMessage(); - return false; - } - server_data_.reset(static_cast<DictionaryValue*>(value.release())); - int port = 0; - if (!server_data_->GetInteger("port", &port)) { - LOG(ERROR) << "Could not find port value"; - return false; - } - if ((port <= 0) || (port > kuint16max)) { - LOG(ERROR) << "Invalid port value: " << port; - return false; - } - host_port_pair_.set_port(port); - return true; -} - } // namespace net diff --git a/net/url_request/url_request_about_job.cc b/net/url_request/url_request_about_job.cc index b8dab1a..f48e72a 100644 --- a/net/url_request/url_request_about_job.cc +++ b/net/url_request/url_request_about_job.cc @@ -12,16 +12,16 @@ namespace net { +URLRequestAboutJob::URLRequestAboutJob(URLRequest* request) + : URLRequestJob(request) { +} + // static URLRequestJob* URLRequestAboutJob::Factory(URLRequest* request, const std::string& scheme) { return new URLRequestAboutJob(request); } -URLRequestAboutJob::URLRequestAboutJob(URLRequest* request) - : URLRequestJob(request) { -} - void URLRequestAboutJob::Start() { // Start reading asynchronously so that all error reporting and data // callbacks happen as they would for network requests. diff --git a/net/url_request/url_request_about_job.h b/net/url_request/url_request_about_job.h index 7617208..4703830 100644 --- a/net/url_request/url_request_about_job.h +++ b/net/url_request/url_request_about_job.h @@ -17,13 +17,14 @@ class URLRequestAboutJob : public URLRequestJob { public: explicit URLRequestAboutJob(URLRequest* request); + static URLRequest::ProtocolFactory Factory; + + // URLRequestJob: virtual void Start(); virtual bool GetMimeType(std::string* mime_type) const; - static URLRequest::ProtocolFactory Factory; - private: - ~URLRequestAboutJob(); + virtual ~URLRequestAboutJob(); void StartAsync(); }; diff --git a/net/url_request/url_request_context.cc b/net/url_request/url_request_context.cc index 8e0f3bd..6eb075c 100644 --- a/net/url_request/url_request_context.cc +++ b/net/url_request/url_request_context.cc @@ -25,6 +25,10 @@ URLRequestContext::URLRequestContext() is_main_(false) { } +void URLRequestContext::set_cookie_store(CookieStore* cookie_store) { + cookie_store_ = cookie_store; +} + const std::string& URLRequestContext::GetUserAgent(const GURL& url) const { return EmptyString(); } @@ -32,8 +36,4 @@ const std::string& URLRequestContext::GetUserAgent(const GURL& url) const { URLRequestContext::~URLRequestContext() { } -void URLRequestContext::set_cookie_store(CookieStore* cookie_store) { - cookie_store_ = cookie_store; -} - } // namespace net diff --git a/net/url_request/url_request_file_dir_job.cc b/net/url_request/url_request_file_dir_job.cc index badb6b8..1cc15be 100644 --- a/net/url_request/url_request_file_dir_job.cc +++ b/net/url_request/url_request_file_dir_job.cc @@ -33,20 +33,6 @@ URLRequestFileDirJob::URLRequestFileDirJob(URLRequest* request, ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)) { } -URLRequestFileDirJob::~URLRequestFileDirJob() { - DCHECK(read_pending_ == false); - DCHECK(lister_ == NULL); -} - -void URLRequestFileDirJob::Start() { - // Start reading asynchronously so that all error reporting and data - // callbacks happen as they would for network requests. - MessageLoop::current()->PostTask( - FROM_HERE, - method_factory_.NewRunnableMethod( - &URLRequestFileDirJob::StartAsync)); -} - void URLRequestFileDirJob::StartAsync() { DCHECK(!lister_); @@ -63,6 +49,15 @@ void URLRequestFileDirJob::StartAsync() { NotifyHeadersComplete(); } +void URLRequestFileDirJob::Start() { + // Start reading asynchronously so that all error reporting and data + // callbacks happen as they would for network requests. + MessageLoop::current()->PostTask( + FROM_HERE, + method_factory_.NewRunnableMethod( + &URLRequestFileDirJob::StartAsync)); +} + void URLRequestFileDirJob::Kill() { if (canceled_) return; @@ -174,6 +169,11 @@ void URLRequestFileDirJob::OnListDone(int error) { Release(); // The Lister is finished; may delete *this* } +URLRequestFileDirJob::~URLRequestFileDirJob() { + DCHECK(read_pending_ == false); + DCHECK(lister_ == NULL); +} + void URLRequestFileDirJob::CloseLister() { if (lister_) { lister_->Cancel(); @@ -182,25 +182,6 @@ void URLRequestFileDirJob::CloseLister() { } } -bool URLRequestFileDirJob::FillReadBuffer(char *buf, int buf_size, - int *bytes_read) { - DCHECK(bytes_read); - - *bytes_read = 0; - - int count = std::min(buf_size, static_cast<int>(data_.size())); - if (count) { - memcpy(buf, &data_[0], count); - data_.erase(0, count); - *bytes_read = count; - return true; - } else if (list_complete_) { - // EOF - return true; - } - return false; -} - void URLRequestFileDirJob::CompleteRead() { if (read_pending_) { int bytes_read; @@ -221,4 +202,23 @@ void URLRequestFileDirJob::CompleteRead() { } } +bool URLRequestFileDirJob::FillReadBuffer(char *buf, int buf_size, + int *bytes_read) { + DCHECK(bytes_read); + + *bytes_read = 0; + + int count = std::min(buf_size, static_cast<int>(data_.size())); + if (count) { + memcpy(buf, &data_[0], count); + data_.erase(0, count); + *bytes_read = count; + return true; + } else if (list_complete_) { + // EOF + return true; + } + return false; +} + } // namespace net diff --git a/net/url_request/url_request_file_dir_job.h b/net/url_request/url_request_file_dir_job.h index 2b40a98..f938417 100644 --- a/net/url_request/url_request_file_dir_job.h +++ b/net/url_request/url_request_file_dir_job.h @@ -22,9 +22,12 @@ class URLRequestFileDirJob public: URLRequestFileDirJob(URLRequest* request, const FilePath& dir_path); + bool list_complete() const { return list_complete_; } + + virtual void StartAsync(); + // Overridden from URLRequestJob: virtual void Start(); - virtual void StartAsync(); virtual void Kill(); virtual bool ReadRawData(IOBuffer* buf, int buf_size, int *bytes_read); virtual bool GetMimeType(std::string* mime_type) const; @@ -35,12 +38,11 @@ class URLRequestFileDirJob const DirectoryLister::DirectoryListerData& data); virtual void OnListDone(int error); - bool list_complete() const { return list_complete_; } - private: virtual ~URLRequestFileDirJob(); void CloseLister(); + // When we have data and a read has been pending, this function // will fill the response buffer and notify the request // appropriately. diff --git a/net/url_request/url_request_file_job.cc b/net/url_request/url_request_file_job.cc index 7a1599a..0f4c423 100644 --- a/net/url_request/url_request_file_job.cc +++ b/net/url_request/url_request_file_job.cc @@ -82,6 +82,17 @@ class URLRequestFileJob::AsyncResolver }; #endif +URLRequestFileJob::URLRequestFileJob(URLRequest* request, + const FilePath& file_path) + : URLRequestJob(request), + file_path_(file_path), + ALLOW_THIS_IN_INITIALIZER_LIST( + io_callback_(this, &URLRequestFileJob::DidRead)), + is_directory_(false), + remaining_bytes_(0), + ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)) { +} + // static URLRequestJob* URLRequestFileJob::Factory(URLRequest* request, const std::string& scheme) { @@ -111,22 +122,33 @@ URLRequestJob* URLRequestFileJob::Factory(URLRequest* request, return new URLRequestFileJob(request, file_path); } -URLRequestFileJob::URLRequestFileJob(URLRequest* request, - const FilePath& file_path) - : URLRequestJob(request), - file_path_(file_path), - ALLOW_THIS_IN_INITIALIZER_LIST( - io_callback_(this, &URLRequestFileJob::DidRead)), - is_directory_(false), - remaining_bytes_(0), - ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)) { -} +#if defined(OS_CHROMEOS) +static const char* const kLocalAccessWhiteList[] = { + "/home/chronos/user/Downloads", + "/media", + "/mnt/partner_partition", + "/usr/share/chromeos-assets", + "/tmp", + "/var/log", +}; -URLRequestFileJob::~URLRequestFileJob() { -#if defined(OS_WIN) - DCHECK(!async_resolver_); -#endif +// static +bool URLRequestFileJob::AccessDisabled(const FilePath& file_path) { + if (URLRequest::IsFileAccessAllowed()) { // for tests. + return false; + } + + for (size_t i = 0; i < arraysize(kLocalAccessWhiteList); ++i) { + const FilePath white_listed_path(kLocalAccessWhiteList[i]); + // FilePath::operator== should probably handle trailing seperators. + if (white_listed_path == file_path.StripTrailingSeparators() || + white_listed_path.IsParent(file_path)) { + return false; + } + } + return true; } +#endif void URLRequestFileJob::Start() { #if defined(OS_WIN) @@ -204,6 +226,43 @@ bool URLRequestFileJob::ReadRawData(IOBuffer* dest, int dest_size, return false; } +bool URLRequestFileJob::IsRedirectResponse(GURL* location, + int* http_status_code) { + if (is_directory_) { + // This happens when we discovered the file is a directory, so needs a + // slash at the end of the path. + std::string new_path = request_->url().path(); + new_path.push_back('/'); + GURL::Replacements replacements; + replacements.SetPathStr(new_path); + + *location = request_->url().ReplaceComponents(replacements); + *http_status_code = 301; // simulate a permanent redirect + return true; + } + +#if defined(OS_WIN) + // Follow a Windows shortcut. + // We just resolve .lnk file, ignore others. + if (!LowerCaseEqualsASCII(file_path_.Extension(), ".lnk")) + return false; + + FilePath new_path = file_path_; + bool resolved; + resolved = file_util::ResolveShortcut(&new_path); + + // If shortcut is not resolved succesfully, do not redirect. + if (!resolved) + return false; + + *location = FilePathToFileURL(new_path); + *http_status_code = 301; + return true; +#else + return false; +#endif +} + bool URLRequestFileJob::GetContentEncodings( std::vector<Filter::FilterType>* encoding_types) { DCHECK(encoding_types->empty()); @@ -245,6 +304,12 @@ void URLRequestFileJob::SetExtraRequestHeaders( } } +URLRequestFileJob::~URLRequestFileJob() { +#if defined(OS_WIN) + DCHECK(!async_resolver_); +#endif +} + void URLRequestFileJob::DidResolve( bool exists, const base::PlatformFileInfo& file_info) { #if defined(OS_WIN) @@ -323,69 +388,4 @@ void URLRequestFileJob::DidRead(int result) { NotifyReadComplete(result); } -bool URLRequestFileJob::IsRedirectResponse(GURL* location, - int* http_status_code) { - if (is_directory_) { - // This happens when we discovered the file is a directory, so needs a - // slash at the end of the path. - std::string new_path = request_->url().path(); - new_path.push_back('/'); - GURL::Replacements replacements; - replacements.SetPathStr(new_path); - - *location = request_->url().ReplaceComponents(replacements); - *http_status_code = 301; // simulate a permanent redirect - return true; - } - -#if defined(OS_WIN) - // Follow a Windows shortcut. - // We just resolve .lnk file, ignore others. - if (!LowerCaseEqualsASCII(file_path_.Extension(), ".lnk")) - return false; - - FilePath new_path = file_path_; - bool resolved; - resolved = file_util::ResolveShortcut(&new_path); - - // If shortcut is not resolved succesfully, do not redirect. - if (!resolved) - return false; - - *location = FilePathToFileURL(new_path); - *http_status_code = 301; - return true; -#else - return false; -#endif -} - -#if defined(OS_CHROMEOS) -static const char* const kLocalAccessWhiteList[] = { - "/home/chronos/user/Downloads", - "/media", - "/mnt/partner_partition", - "/usr/share/chromeos-assets", - "/tmp", - "/var/log", -}; - -// static -bool URLRequestFileJob::AccessDisabled(const FilePath& file_path) { - if (URLRequest::IsFileAccessAllowed()) { // for tests. - return false; - } - - for (size_t i = 0; i < arraysize(kLocalAccessWhiteList); ++i) { - const FilePath white_listed_path(kLocalAccessWhiteList[i]); - // FilePath::operator== should probably handle trailing seperators. - if (white_listed_path == file_path.StripTrailingSeparators() || - white_listed_path.IsParent(file_path)) { - return false; - } - } - return true; -} -#endif - } // namespace net diff --git a/net/url_request/url_request_file_job.h b/net/url_request/url_request_file_job.h index 1a09b04..4dbcb0b 100644 --- a/net/url_request/url_request_file_job.h +++ b/net/url_request/url_request_file_job.h @@ -28,6 +28,13 @@ class URLRequestFileJob : public URLRequestJob { public: URLRequestFileJob(URLRequest* request, const FilePath& file_path); + static URLRequest::ProtocolFactory Factory; + +#if defined(OS_CHROMEOS) + static bool AccessDisabled(const FilePath& file_path); +#endif + + // URLRequestJob: virtual void Start(); virtual void Kill(); virtual bool ReadRawData(IOBuffer* buf, int buf_size, int* bytes_read); @@ -37,12 +44,6 @@ class URLRequestFileJob : public URLRequestJob { virtual bool GetMimeType(std::string* mime_type) const; virtual void SetExtraRequestHeaders(const HttpRequestHeaders& headers); - static URLRequest::ProtocolFactory Factory; - -#if defined(OS_CHROMEOS) - static bool AccessDisabled(const FilePath& file_path); -#endif - protected: virtual ~URLRequestFileJob(); diff --git a/net/url_request/url_request_job_manager.cc b/net/url_request/url_request_job_manager.cc index f311cc4..5fd7be6 100644 --- a/net/url_request/url_request_job_manager.cc +++ b/net/url_request/url_request_job_manager.cc @@ -39,15 +39,6 @@ static const SchemeToFactory kBuiltinFactories[] = { { "data", URLRequestDataJob::Factory }, }; -URLRequestJobManager::URLRequestJobManager() : enable_file_access_(false) { -#ifndef NDEBUG - allowed_thread_ = 0; - allowed_thread_initialized_ = false; -#endif -} - -URLRequestJobManager::~URLRequestJobManager() {} - // static URLRequestJobManager* URLRequestJobManager::GetInstance() { return Singleton<URLRequestJobManager>::get(); @@ -215,4 +206,13 @@ void URLRequestJobManager::UnregisterRequestInterceptor( interceptors_.erase(i); } +URLRequestJobManager::URLRequestJobManager() : enable_file_access_(false) { +#ifndef NDEBUG + allowed_thread_ = 0; + allowed_thread_initialized_ = false; +#endif +} + +URLRequestJobManager::~URLRequestJobManager() {} + } // namespace net diff --git a/net/url_request/url_request_job_manager.h b/net/url_request/url_request_job_manager.h index e4efcf5..ca9ada9 100644 --- a/net/url_request/url_request_job_manager.h +++ b/net/url_request/url_request_job_manager.h @@ -76,17 +76,7 @@ class URLRequestJobManager { URLRequestJobManager(); ~URLRequestJobManager(); - mutable base::Lock lock_; - FactoryMap factories_; - InterceptorList interceptors_; - bool enable_file_access_; - #ifndef NDEBUG - // We use this to assert that CreateJob and the registration functions all - // run on the same thread. - mutable base::PlatformThreadId allowed_thread_; - mutable bool allowed_thread_initialized_; - // The first guy to call this function sets the allowed thread. This way we // avoid needing to define that thread externally. Since we expect all // callers to be on the same thread, we don't worry about threads racing to @@ -110,8 +100,18 @@ class URLRequestJobManager { return true; #endif } + + // We use this to assert that CreateJob and the registration functions all + // run on the same thread. + mutable base::PlatformThreadId allowed_thread_; + mutable bool allowed_thread_initialized_; #endif + mutable base::Lock lock_; + FactoryMap factories_; + InterceptorList interceptors_; + bool enable_file_access_; + DISALLOW_COPY_AND_ASSIGN(URLRequestJobManager); }; |