diff options
author | noelutz@google.com <noelutz@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-06-17 17:54:50 +0000 |
---|---|---|
committer | noelutz@google.com <noelutz@google.com@0039d316-1c4b-4281-b951-d872f2087c98> | 2011-06-17 17:54:50 +0000 |
commit | cae64fedec3de095c3053d7f7083454675bfdcdd (patch) | |
tree | fa36ec2121b0bc67bcad46054a0cdb1a6ddd3ce6 /chrome/browser/safe_browsing | |
parent | c0929506d6d653ce7ed7c14ea5c747b2abcc9f15 (diff) | |
download | chromium_src-cae64fedec3de095c3053d7f7083454675bfdcdd.zip chromium_src-cae64fedec3de095c3053d7f7083454675bfdcdd.tar.gz chromium_src-cae64fedec3de095c3053d7f7083454675bfdcdd.tar.bz2 |
Initial CL to update the client model more frequently. This CL also changes the way we send the model from the browser to the renderer. Instead of sending a file descriptor we send the actual model over an IPC.
BUG=None
TEST=None
Committed: http://src.chromium.org/viewvc/chrome?view=rev&revision=89098
Review URL: http://codereview.chromium.org/7057025
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@89509 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'chrome/browser/safe_browsing')
4 files changed, 277 insertions, 246 deletions
diff --git a/chrome/browser/safe_browsing/client_side_detection_host_unittest.cc b/chrome/browser/safe_browsing/client_side_detection_host_unittest.cc index 5a19583..c490bd1 100644 --- a/chrome/browser/safe_browsing/client_side_detection_host_unittest.cc +++ b/chrome/browser/safe_browsing/client_side_detection_host_unittest.cc @@ -5,7 +5,6 @@ #include "base/file_path.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "base/scoped_temp_dir.h" #include "base/task.h" #include "chrome/browser/safe_browsing/client_side_detection_host.h" #include "chrome/browser/safe_browsing/client_side_detection_service.h" @@ -44,15 +43,14 @@ const bool kTrue = true; } namespace safe_browsing { - +namespace { MATCHER_P(EqualsProto, other, "") { return other.SerializeAsString() == arg.SerializeAsString(); } class MockClientSideDetectionService : public ClientSideDetectionService { public: - explicit MockClientSideDetectionService(const FilePath& model_path) - : ClientSideDetectionService(model_path, NULL) {} + MockClientSideDetectionService() : ClientSideDetectionService(NULL) {} virtual ~MockClientSideDetectionService() {}; MOCK_METHOD2(SendClientReportPhishingRequest, @@ -104,6 +102,7 @@ void QuitUIMessageLoop() { FROM_HERE, new MessageLoop::QuitTask()); } +} // namespace class ClientSideDetectionHostTest : public TabContentsWrapperTestHarness { public: @@ -122,12 +121,7 @@ class ClientSideDetectionHostTest : public TabContentsWrapperTestHarness { ASSERT_TRUE(io_thread_->Start()); // Inject service classes. - ScopedTempDir tmp_dir; - ASSERT_TRUE(tmp_dir.CreateUniqueTempDir()); - FilePath model_path = tmp_dir.path().AppendASCII("model"); - - csd_service_.reset(new StrictMock<MockClientSideDetectionService>( - model_path)); + csd_service_.reset(new StrictMock<MockClientSideDetectionService>()); sb_service_ = new StrictMock<MockSafeBrowsingService>(); csd_host_ = contents_wrapper()->safebrowsing_detection_host(); csd_host_->set_client_side_detection_service(csd_service_.get()); diff --git a/chrome/browser/safe_browsing/client_side_detection_service.cc b/chrome/browser/safe_browsing/client_side_detection_service.cc index 7122bd6..c3cd4a0 100644 --- a/chrome/browser/safe_browsing/client_side_detection_service.cc +++ b/chrome/browser/safe_browsing/client_side_detection_service.cc @@ -5,17 +5,17 @@ #include "chrome/browser/safe_browsing/client_side_detection_service.h" #include "base/command_line.h" -#include "base/file_path.h" #include "base/file_util_proxy.h" #include "base/logging.h" +#include "base/time.h" #include "base/memory/scoped_ptr.h" #include "base/message_loop.h" #include "base/metrics/histogram.h" -#include "base/platform_file.h" #include "base/stl_util-inl.h" #include "base/task.h" #include "base/time.h" #include "chrome/common/net/http_return.h" +#include "chrome/common/safe_browsing/client_model.pb.h" #include "chrome/common/safe_browsing/csd.pb.h" #include "chrome/common/safe_browsing/safebrowsing_messages.h" #include "content/browser/browser_thread.h" @@ -23,18 +23,19 @@ #include "content/common/notification_service.h" #include "content/common/url_fetcher.h" #include "googleurl/src/gurl.h" -#include "ipc/ipc_platform_file.h" #include "net/base/load_flags.h" +#include "net/http/http_response_headers.h" #include "net/url_request/url_request_context_getter.h" #include "net/url_request/url_request_status.h" -#if defined(OS_MACOSX) -#include "base/mac/mac_util.h" -#endif - namespace safe_browsing { +const size_t ClientSideDetectionService::kMaxModelSizeBytes = 90 * 1024; const int ClientSideDetectionService::kMaxReportsPerInterval = 3; +// TODO(noelutz): once we know this mechanism works as intended we should fetch +// the model much more frequently. E.g., every 5 minutes or so. +const int ClientSideDetectionService::kClientModelFetchIntervalMs = 3600 * 1000; +const int ClientSideDetectionService::kInitialClientModelFetchDelayMs = 10000; const base::TimeDelta ClientSideDetectionService::kReportsInterval = base::TimeDelta::FromDays(1); @@ -48,11 +49,8 @@ const char ClientSideDetectionService::kClientReportPhishingUrl[] = // Note: when updatng the model version, don't forget to change the filename // in chrome/common/chrome_constants.cc as well, or else existing users won't // download the new model. -// -// TODO(bryner): add version metadata so that clients can download new models -// without needing a new model filename. const char ClientSideDetectionService::kClientModelUrl[] = - "https://ssl.gstatic.com/safebrowsing/csd/client_model_v1.pb"; + "https://ssl.gstatic.com/safebrowsing/csd/client_model_v2.pb"; struct ClientSideDetectionService::ClientReportInfo { scoped_ptr<ClientReportPhishingRequestCallback> callback; @@ -64,13 +62,10 @@ ClientSideDetectionService::CacheState::CacheState(bool phish, base::Time time) timestamp(time) {} ClientSideDetectionService::ClientSideDetectionService( - const FilePath& model_path, net::URLRequestContextGetter* request_context_getter) - : model_path_(model_path), - model_status_(UNKNOWN_STATUS), - model_file_(base::kInvalidPlatformFileValue), + : model_version_(-1), + tmp_model_version_(-1), ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), - ALLOW_THIS_IN_INITIALIZER_LIST(callback_factory_(this)), request_context_getter_(request_context_getter) { registrar_.Add(this, NotificationType::RENDERER_PROCESS_CREATED, NotificationService::AllSources()); @@ -81,40 +76,39 @@ ClientSideDetectionService::~ClientSideDetectionService() { STLDeleteContainerPairPointers(client_phishing_reports_.begin(), client_phishing_reports_.end()); client_phishing_reports_.clear(); - CloseModelFile(); } /* static */ ClientSideDetectionService* ClientSideDetectionService::Create( - const FilePath& model_path, + const FilePath& model_dir, net::URLRequestContextGetter* request_context_getter) { DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); scoped_ptr<ClientSideDetectionService> service( - new ClientSideDetectionService(model_path, request_context_getter)); + new ClientSideDetectionService(request_context_getter)); if (!service->InitializePrivateNetworks()) { UMA_HISTOGRAM_COUNTS("SBClientPhishing.InitPrivateNetworksFailed", 1); return NULL; } + // We fetch the model at every browser restart. In a lot of cases the model + // will be in the cache so it won't actually be fetched from the network. + // We delay the first model fetch to avoid slowing down browser startup. + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + service->method_factory_.NewRunnableMethod( + &ClientSideDetectionService::StartFetchModel), + kInitialClientModelFetchDelayMs); - // We try to open the model file right away and start fetching it if - // it does not already exist on disk. - base::FileUtilProxy::CreateOrOpenCallback* cb = - service.get()->callback_factory_.NewCallback( - &ClientSideDetectionService::OpenModelFileDone); - if (!base::FileUtilProxy::CreateOrOpen( - BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE), - model_path, - base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ, - cb)) { - delete cb; - return NULL; - } - - // Delete the previous-version model file. - // TODO(bryner): Remove this for M14. + // Delete the previous-version model files. + // TODO(bryner): Remove this for M15 (including the model_dir argument to + // Create()). base::FileUtilProxy::Delete( BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE), - model_path.DirName().AppendASCII("Safe Browsing Phishing Model"), + model_dir.AppendASCII("Safe Browsing Phishing Model"), + false /* not recursive */, + NULL /* not interested in result */); + base::FileUtilProxy::Delete( + BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE), + model_dir.AppendASCII("Safe Browsing Phishing Model v1"), false /* not recursive */, NULL /* not interested in result */); return service.release(); @@ -170,122 +164,72 @@ void ClientSideDetectionService::OnURLFetchComplete( void ClientSideDetectionService::Observe(NotificationType type, const NotificationSource& source, const NotificationDetails& details) { + DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); DCHECK(type == NotificationType::RENDERER_PROCESS_CREATED); - if (model_status_ == UNKNOWN_STATUS) { - // The model isn't ready. When it's known, we'll call all renderers. + if (model_version_ < 0) { + // Model might not be ready or maybe there was an error. return; } - - RenderProcessHost* process = Source<RenderProcessHost>(source).ptr(); - SendModelToProcess(process); + SendModelToProcess(Source<RenderProcessHost>(source).ptr()); } void ClientSideDetectionService::SendModelToProcess( RenderProcessHost* process) { - if (model_file_ == base::kInvalidPlatformFileValue) - return; - - IPC::PlatformFileForTransit file; -#if defined(OS_POSIX) - file = base::FileDescriptor(model_file_, false); -#elif defined(OS_WIN) - ::DuplicateHandle(::GetCurrentProcess(), model_file_, process->GetHandle(), - &file, 0, false, DUPLICATE_SAME_ACCESS); -#endif VLOG(2) << "Sending phishing model to renderer"; - process->Send(new SafeBrowsingMsg_SetPhishingModel(file)); + process->Send(new SafeBrowsingMsg_SetPhishingModel(model_str_)); } -void ClientSideDetectionService::SetModelStatus(ModelStatus status) { - DCHECK_NE(READY_STATUS, model_status_); - model_status_ = status; - +void ClientSideDetectionService::SendModelToRenderers() { for (RenderProcessHost::iterator i(RenderProcessHost::AllHostsIterator()); !i.IsAtEnd(); i.Advance()) { - RenderProcessHost* process = i.GetCurrentValue(); - if (process->GetHandle()) - SendModelToProcess(process); + SendModelToProcess(i.GetCurrentValue()); } } -void ClientSideDetectionService::OpenModelFileDone( - base::PlatformFileError error_code, - base::PassPlatformFile file, - bool created) { - DCHECK(!created); - if (base::PLATFORM_FILE_OK == error_code) { - // The model file already exists. There is no need to fetch the model. - model_file_ = file.ReleaseValue(); - SetModelStatus(READY_STATUS); -#if defined(OS_MACOSX) - base::mac::SetFileBackupExclusion(model_path_); -#endif - } else if (base::PLATFORM_FILE_ERROR_NOT_FOUND == error_code) { - // We need to fetch the model since it does not exist yet. - model_fetcher_.reset(URLFetcher::Create(0 /* ID is not used */, - GURL(kClientModelUrl), - URLFetcher::GET, - this)); - model_fetcher_->set_request_context(request_context_getter_.get()); - model_fetcher_->Start(); - } else { - // It is not clear what we should do in this case. For now we simply fail. - // Hopefully, we'll be able to read the model during the next browser - // restart. - SetModelStatus(ERROR_STATUS); - } +void ClientSideDetectionService::StartFetchModel() { + // Start fetching the model either from the cache or possibly from the + // network if the model isn't in the cache. + model_fetcher_.reset(URLFetcher::Create(0 /* ID is not used */, + GURL(kClientModelUrl), + URLFetcher::GET, + this)); + model_fetcher_->set_request_context(request_context_getter_.get()); + model_fetcher_->Start(); } -void ClientSideDetectionService::CreateModelFileDone( - base::PlatformFileError error_code, - base::PassPlatformFile file, - bool created) { - model_file_ = file.ReleaseValue(); - base::FileUtilProxy::WriteCallback* cb = callback_factory_.NewCallback( - &ClientSideDetectionService::WriteModelFileDone); - if (!created || - base::PLATFORM_FILE_OK != error_code || - !base::FileUtilProxy::Write( - BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE), - model_file_, - 0 /* offset */, tmp_model_string_->data(), tmp_model_string_->size(), - cb)) { - delete cb; - // An error occurred somewhere. We close the model file if necessary and - // then run all the pending callbacks giving them an invalid model file. - CloseModelFile(); - SetModelStatus(ERROR_STATUS); -#if defined(OS_MACOSX) - } else { - base::mac::SetFileBackupExclusion(model_path_); -#endif +void ClientSideDetectionService::EndFetchModel(ClientModelStatus status) { + UMA_HISTOGRAM_ENUMERATION("SBClientPhishing.ClientModelStatus", + status, + MODEL_STATUS_MAX); + // If there is already a valid model but we're unable to reload one + // we leave the old model. + if (status == MODEL_SUCCESS) { + // Replace the model string and version; + model_str_.swap(tmp_model_str_); + model_version_ = tmp_model_version_; + SendModelToRenderers(); } -} - -void ClientSideDetectionService::WriteModelFileDone( - base::PlatformFileError error_code, - int bytes_written) { - if (base::PLATFORM_FILE_OK == error_code) { - SetModelStatus(READY_STATUS); - } else { - // TODO(noelutz): maybe we should retry writing the model since we - // did already fetch the model? - CloseModelFile(); - SetModelStatus(ERROR_STATUS); + tmp_model_str_.clear(); + tmp_model_version_ = -1; + + int delay_ms = kClientModelFetchIntervalMs; + // If the most recently fetched model had a valid max-age and the model was + // valid we're scheduling the next model update for after the max-age expired. + if (tmp_model_max_age_.get() && + (status == MODEL_SUCCESS || status == MODEL_NOT_CHANGED)) { + // We're adding 60s of additional delay to make sure we're past + // the model's age. + *tmp_model_max_age_ += base::TimeDelta::FromMinutes(1); + delay_ms = tmp_model_max_age_->InMilliseconds(); } - // Delete the model string that we kept around while we were writing the - // string to disk - we don't need it anymore. - tmp_model_string_.reset(); -} + tmp_model_max_age_.reset(); -void ClientSideDetectionService::CloseModelFile() { - if (model_file_ != base::kInvalidPlatformFileValue) { - base::FileUtilProxy::Close( - BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE), - model_file_, - NULL); - } - model_file_ = base::kInvalidPlatformFileValue; + // Schedule the next model reload. + MessageLoop::current()->PostDelayedTask( + FROM_HERE, + method_factory_.NewRunnableMethod( + &ClientSideDetectionService::StartFetchModel), + delay_ms); } void ClientSideDetectionService::StartClientReportPhishingRequest( @@ -330,28 +274,35 @@ void ClientSideDetectionService::HandleModelResponse( int response_code, const net::ResponseCookies& cookies, const std::string& data) { - if (status.is_success() && RC_REQUEST_OK == response_code) { - // Copy the model because it has to be accessible after this function - // returns. Once we have written the model to a file we will delete the - // temporary model string. TODO(noelutz): don't store the model to disk if - // it's invalid. - tmp_model_string_.reset(new std::string(data)); - base::FileUtilProxy::CreateOrOpenCallback* cb = - callback_factory_.NewCallback( - &ClientSideDetectionService::CreateModelFileDone); - if (!base::FileUtilProxy::CreateOrOpen( - BrowserThread::GetMessageLoopProxyForThread(BrowserThread::FILE), - model_path_, - base::PLATFORM_FILE_CREATE_ALWAYS | - base::PLATFORM_FILE_WRITE | - base::PLATFORM_FILE_READ, - cb)) { - delete cb; - SetModelStatus(ERROR_STATUS); - } + base::TimeDelta max_age; + if (status.is_success() && RC_REQUEST_OK == response_code && + source->response_headers() && + source->response_headers()->GetMaxAgeValue(&max_age)) { + tmp_model_max_age_.reset(new base::TimeDelta(max_age)); + } + ClientSideModel model; + ClientModelStatus model_status; + if (!status.is_success() || RC_REQUEST_OK != response_code) { + model_status = MODEL_FETCH_FAILED; + } else if (data.empty()) { + model_status = MODEL_EMPTY; + } else if (data.size() > kMaxModelSizeBytes) { + model_status = MODEL_TOO_LARGE; + } else if (!model.ParseFromString(data)) { + model_status = MODEL_PARSE_ERROR; + } else if (!model.IsInitialized() || !model.has_version()) { + model_status = MODEL_MISSING_FIELDS; + } else if (model.version() < 0 || + (model_version_ > 0 && model.version() < model_version_)) { + model_status = MODEL_INVALID_VERSION_NUMBER; + } else if (model.version() == model_version_) { + model_status = MODEL_NOT_CHANGED; } else { - SetModelStatus(ERROR_STATUS); + tmp_model_version_ = model.version(); + tmp_model_str_.assign(data); + model_status = MODEL_SUCCESS; } + EndFetchModel(model_status); } void ClientSideDetectionService::HandlePhishingVerdict( @@ -470,5 +421,4 @@ bool ClientSideDetectionService::InitializePrivateNetworks() { } return true; } - } // namespace safe_browsing diff --git a/chrome/browser/safe_browsing/client_side_detection_service.h b/chrome/browser/safe_browsing/client_side_detection_service.h index 4bf515f..39973a1 100644 --- a/chrome/browser/safe_browsing/client_side_detection_service.h +++ b/chrome/browser/safe_browsing/client_side_detection_service.h @@ -3,13 +3,12 @@ // found in the LICENSE file. // // Helper class which handles communication with the SafeBrowsing backends for -// client-side phishing detection. This class can be used to get a file -// descriptor to the client-side phishing model and also to send a ping back to -// Google to verify if a particular site is really phishing or not. +// client-side phishing detection. This class is used to fetch the client-side +// model and send it to all renderers. This class is also used to send a ping +// back to Google to verify if a particular site is really phishing or not. // // This class is not thread-safe and expects all calls to be made on the UI -// thread. We also expect that the calling thread runs a message loop and that -// there is a FILE thread running to execute asynchronous file operations. +// thread. We also expect that the calling thread runs a message loop. #ifndef CHROME_BROWSER_SAFE_BROWSING_CLIENT_SIDE_DETECTION_SERVICE_H_ #define CHROME_BROWSER_SAFE_BROWSING_CLIENT_SIDE_DETECTION_SERVICE_H_ @@ -22,14 +21,11 @@ #include <vector> #include "base/basictypes.h" -#include "base/callback.h" -#include "base/file_path.h" +#include "base/callback_old.h" #include "base/gtest_prod_util.h" #include "base/memory/linked_ptr.h" #include "base/memory/ref_counted.h" -#include "base/memory/scoped_callback_factory.h" #include "base/memory/scoped_ptr.h" -#include "base/platform_file.h" #include "base/task.h" #include "base/time.h" #include "content/common/notification_observer.h" @@ -40,6 +36,10 @@ class RenderProcessHost; +namespace base { +class TimeDelta; +} + namespace net { class URLRequestContextGetter; class URLRequestStatus; @@ -57,10 +57,10 @@ class ClientSideDetectionService : public URLFetcher::Delegate, virtual ~ClientSideDetectionService(); // Creates a client-side detection service and starts fetching the client-side - // detection model if necessary. The model will be stored in |model_path|. - // The caller takes ownership of the object. This function may return NULL. + // detection model if necessary. The caller takes ownership of the object. + // This function may return NULL. static ClientSideDetectionService* Create( - const FilePath& model_path, + const FilePath& model_dir, net::URLRequestContextGetter* request_context_getter); // From the URLFetcher::Delegate interface. @@ -110,21 +110,35 @@ class ClientSideDetectionService : public URLFetcher::Delegate, protected: // Use Create() method to create an instance of this object. - ClientSideDetectionService( - const FilePath& model_path, + explicit ClientSideDetectionService( net::URLRequestContextGetter* request_context_getter); + // Enum used to keep stats about why we fail to get the client model. + enum ClientModelStatus { + MODEL_SUCCESS, + MODEL_NOT_CHANGED, + MODEL_FETCH_FAILED, + MODEL_EMPTY, + MODEL_TOO_LARGE, + MODEL_PARSE_ERROR, + MODEL_MISSING_FIELDS, + MODEL_INVALID_VERSION_NUMBER, + MODEL_STATUS_MAX // Always add new values before this one. + }; + + // Starts fetching the model from the network or the cache. This method + // is called periodically to check whether a new client model is available + // for download. + void StartFetchModel(); + + // This method is called when we're done fetching the model either because + // we hit an error somewhere or because we're actually done fetch and + // validating the model. + virtual void EndFetchModel(ClientModelStatus status); // Virtual for testing. + private: friend class ClientSideDetectionServiceTest; - - enum ModelStatus { - // It's unclear whether or not the model was already fetched. - UNKNOWN_STATUS, - // Model is fetched and is stored on disk. - READY_STATUS, - // Error occured during fetching or writing. - ERROR_STATUS, - }; + FRIEND_TEST_ALL_PREFIXES(ClientSideDetectionServiceTest, FetchModelTest); // CacheState holds all information necessary to respond to a caller without // actually making a HTTP request. @@ -142,41 +156,14 @@ class ClientSideDetectionService : public URLFetcher::Delegate, static const char kClientReportPhishingUrl[]; static const char kClientModelUrl[]; + static const size_t kMaxModelSizeBytes; static const int kMaxReportsPerInterval; + static const int kClientModelFetchIntervalMs; + static const int kInitialClientModelFetchDelayMs; static const base::TimeDelta kReportsInterval; static const base::TimeDelta kNegativeCacheInterval; static const base::TimeDelta kPositiveCacheInterval; - // Sets the model status and invokes all the pending callbacks in - // |open_callbacks_| with the current |model_file_| as parameter. - void SetModelStatus(ModelStatus status); - - // Called once the initial open() of the model file is done. If the file - // exists we're done and we can call all the pending callbacks. If the - // file doesn't exist this method will asynchronously fetch the model - // from the server by invoking StartFetchingModel(). - void OpenModelFileDone(base::PlatformFileError error_code, - base::PassPlatformFile file, - bool created); - - // Callback that is invoked once the attempt to create the model - // file on disk is done. If the file was created successfully we - // start writing the model to disk (asynchronously). Otherwise, we - // give up and send an invalid platform file to all the pending callbacks. - void CreateModelFileDone(base::PlatformFileError error_code, - base::PassPlatformFile file, - bool created); - - // Callback is invoked once we're done writing the model file to disk. - // If everything went well then |model_file_| is a valid platform file - // that can be sent to all the pending callbacks. If an error occurs - // we give up and send an invalid platform file to all the pending callbacks. - void WriteModelFileDone(base::PlatformFileError error_code, - int bytes_written); - - // Helper function which closes the |model_file_| if necessary. - void CloseModelFile(); - // Starts sending the request to the client-side detection frontends. // This method takes ownership of both pointers. void StartClientReportPhishingRequest( @@ -214,11 +201,16 @@ class ClientSideDetectionService : public URLFetcher::Delegate, // Send the model to the given renderer. void SendModelToProcess(RenderProcessHost* process); - FilePath model_path_; - ModelStatus model_status_; - base::PlatformFile model_file_; + // Same as above but sends the model to all rendereres. + void SendModelToRenderers(); + + std::string model_str_; + int model_version_; scoped_ptr<URLFetcher> model_fetcher_; - scoped_ptr<std::string> tmp_model_string_; + + std::string tmp_model_str_; + int tmp_model_version_; + scoped_ptr<base::TimeDelta> tmp_model_max_age_; // Map of client report phishing request to the corresponding callback that // has to be invoked when the request is done. @@ -242,12 +234,6 @@ class ClientSideDetectionService : public URLFetcher::Delegate, // SendClientReportPhishingRequest. ScopedRunnableMethodFactory<ClientSideDetectionService> method_factory_; - // The client-side detection service object (this) might go away before some - // of the callbacks are done (e.g., asynchronous file operations). The - // callback factory will revoke all pending callbacks if this goes away to - // avoid a crash. - base::ScopedCallbackFactory<ClientSideDetectionService> callback_factory_; - // The context we use to issue network requests. scoped_refptr<net::URLRequestContextGetter> request_context_getter_; diff --git a/chrome/browser/safe_browsing/client_side_detection_service_unittest.cc b/chrome/browser/safe_browsing/client_side_detection_service_unittest.cc index 3b5d37f..5e2897d 100644 --- a/chrome/browser/safe_browsing/client_side_detection_service_unittest.cc +++ b/chrome/browser/safe_browsing/client_side_detection_service_unittest.cc @@ -8,25 +8,42 @@ #include "base/callback.h" #include "base/file_path.h" -#include "base/file_util.h" -#include "base/file_util_proxy.h" #include "base/logging.h" #include "base/memory/scoped_ptr.h" #include "base/message_loop.h" -#include "base/platform_file.h" #include "base/scoped_temp_dir.h" #include "base/task.h" #include "base/time.h" #include "chrome/browser/safe_browsing/client_side_detection_service.h" +#include "chrome/common/safe_browsing/client_model.pb.h" #include "chrome/common/safe_browsing/csd.pb.h" #include "content/browser/browser_thread.h" #include "content/common/test_url_fetcher_factory.h" #include "content/common/url_fetcher.h" #include "googleurl/src/gurl.h" #include "net/url_request/url_request_status.h" +#include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" +using ::testing::Mock; + namespace safe_browsing { +namespace { +class MockClientSideDetectionService : public ClientSideDetectionService { + public: + MockClientSideDetectionService() : ClientSideDetectionService(NULL) {} + virtual ~MockClientSideDetectionService() {} + + MOCK_METHOD1(EndFetchModel, void(ClientModelStatus)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockClientSideDetectionService); +}; + +ACTION(QuitCurrentMessageLoop) { + MessageLoop::current()->Quit(); +} +} // namespace class ClientSideDetectionServiceTest : public testing::Test { protected: @@ -47,13 +64,6 @@ class ClientSideDetectionServiceTest : public testing::Test { browser_thread_.reset(); } - std::string ReadModelFile(base::PlatformFile model_file) { - char buf[1024]; - int n = base::ReadPlatformFile(model_file, 0, buf, 1024); - EXPECT_LE(0, n); - return (n < 0) ? "" : std::string(buf, n); - } - bool SendClientReportPhishingRequest(const GURL& phishing_url, float score) { ClientPhishingRequest* request = new ClientPhishingRequest(); @@ -159,12 +169,107 @@ class ClientSideDetectionServiceTest : public testing::Test { bool is_phishing_; }; +TEST_F(ClientSideDetectionServiceTest, FetchModelTest) { + // We don't want to use a real service class here because we can't call + // the real EndFetchModel. It would reschedule a reload which might + // make the test flaky. + MockClientSideDetectionService service; + + // The model fetch failed. + SetModelFetchResponse("blamodel", false /* failure */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_FETCH_FAILED)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Empty model file. + SetModelFetchResponse("", true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_EMPTY)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Model is too large. + SetModelFetchResponse( + std::string(ClientSideDetectionService::kMaxModelSizeBytes + 1, 'x'), + true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_TOO_LARGE)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Unable to parse the model file. + SetModelFetchResponse("Invalid model file", true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_PARSE_ERROR)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Model that is missing some required fields (missing the version field). + ClientSideModel model; + model.set_max_words_per_term(4); + SetModelFetchResponse(model.SerializePartialAsString(), true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_MISSING_FIELDS)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Model version number is wrong. + model.set_version(-1); + SetModelFetchResponse(model.SerializeAsString(), true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_INVALID_VERSION_NUMBER)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Normal model. + model.set_version(10); + SetModelFetchResponse(model.SerializeAsString(), true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_SUCCESS)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Model version number is decreasing. + service.model_version_ = 11; + SetModelFetchResponse(model.SerializeAsString(), true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_INVALID_VERSION_NUMBER)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); + + // Model version hasn't changed since the last reload. + service.model_version_ = 10; + SetModelFetchResponse(model.SerializeAsString(), true /* success */); + EXPECT_CALL(service, EndFetchModel( + ClientSideDetectionService::MODEL_NOT_CHANGED)) + .WillOnce(QuitCurrentMessageLoop()); + service.StartFetchModel(); + msg_loop_.Run(); // EndFetchModel will quit the message loop. + Mock::VerifyAndClearExpectations(&service); +} + TEST_F(ClientSideDetectionServiceTest, ServiceObjectDeletedBeforeCallbackDone) { SetModelFetchResponse("bogus model", true /* success */); ScopedTempDir tmp_dir; ASSERT_TRUE(tmp_dir.CreateUniqueTempDir()); - csd_service_.reset(ClientSideDetectionService::Create( - tmp_dir.path().AppendASCII("model"), NULL)); + csd_service_.reset(ClientSideDetectionService::Create(tmp_dir.path(), NULL)); EXPECT_TRUE(csd_service_.get() != NULL); // We delete the client-side detection service class even though the callbacks // haven't run yet. @@ -178,8 +283,7 @@ TEST_F(ClientSideDetectionServiceTest, SendClientReportPhishingRequest) { SetModelFetchResponse("bogus model", true /* success */); ScopedTempDir tmp_dir; ASSERT_TRUE(tmp_dir.CreateUniqueTempDir()); - csd_service_.reset(ClientSideDetectionService::Create( - tmp_dir.path().AppendASCII("model"), NULL)); + csd_service_.reset(ClientSideDetectionService::Create(tmp_dir.path(), NULL)); GURL url("http://a.com/"); float score = 0.4f; // Some random client score. @@ -228,8 +332,7 @@ TEST_F(ClientSideDetectionServiceTest, GetNumReportTest) { SetModelFetchResponse("bogus model", true /* success */); ScopedTempDir tmp_dir; ASSERT_TRUE(tmp_dir.CreateUniqueTempDir()); - csd_service_.reset(ClientSideDetectionService::Create( - tmp_dir.path().AppendASCII("model"), NULL)); + csd_service_.reset(ClientSideDetectionService::Create(tmp_dir.path(), NULL)); std::queue<base::Time>& report_times = GetPhishingReportTimes(); base::Time now = base::Time::Now(); @@ -246,8 +349,7 @@ TEST_F(ClientSideDetectionServiceTest, CacheTest) { SetModelFetchResponse("bogus model", true /* success */); ScopedTempDir tmp_dir; ASSERT_TRUE(tmp_dir.CreateUniqueTempDir()); - csd_service_.reset(ClientSideDetectionService::Create( - tmp_dir.path().AppendASCII("model"), NULL)); + csd_service_.reset(ClientSideDetectionService::Create(tmp_dir.path(), NULL)); TestCache(); } @@ -256,8 +358,7 @@ TEST_F(ClientSideDetectionServiceTest, IsPrivateIPAddress) { SetModelFetchResponse("bogus model", true /* success */); ScopedTempDir tmp_dir; ASSERT_TRUE(tmp_dir.CreateUniqueTempDir()); - csd_service_.reset(ClientSideDetectionService::Create( - tmp_dir.path().AppendASCII("model"), NULL)); + csd_service_.reset(ClientSideDetectionService::Create(tmp_dir.path(), NULL)); EXPECT_TRUE(csd_service_->IsPrivateIPAddress("10.1.2.3")); EXPECT_TRUE(csd_service_->IsPrivateIPAddress("127.0.0.1")); |