diff options
-rw-r--r-- | chrome/browser/safe_browsing/local_database_manager.cc | 31 | ||||
-rw-r--r-- | chrome/browser/safe_browsing/local_database_manager_unittest.cc | 128 |
2 files changed, 108 insertions, 51 deletions
diff --git a/chrome/browser/safe_browsing/local_database_manager.cc b/chrome/browser/safe_browsing/local_database_manager.cc index a4153e6..19e52ca 100644 --- a/chrome/browser/safe_browsing/local_database_manager.cc +++ b/chrome/browser/safe_browsing/local_database_manager.cc @@ -1076,27 +1076,36 @@ bool LocalSafeBrowsingDatabaseManager::HandleOneCheck( // which are called from this code. Refactoring that across the checks could // interact well with batching the checks here. - // TODO(gab): Fix the fact that Get(Url|Hash)SeverestThreatType() may return a - // threat for which IsExpectedThreat() returns false even if |full_hashes| - // actually contains an expected threat. + std::vector<SBFullHashResult> expected_full_hashes; + for (const auto& full_hash : full_hashes) { + ListType type = static_cast<ListType>(full_hash.list_id); + if (IsExpectedThreat(GetThreatTypeFromListType(type), + check->expected_threats)) { + expected_full_hashes.push_back(full_hash); + } + } + + if (expected_full_hashes.empty()) { + SafeBrowsingCheckDone(check); + return false; + } for (size_t i = 0; i < check->urls.size(); ++i) { size_t threat_index; - SBThreatType threat = - GetUrlSeverestThreatType(check->urls[i], full_hashes, &threat_index); - if (threat != SB_THREAT_TYPE_SAFE && - IsExpectedThreat(threat, check->expected_threats)) { + SBThreatType threat = GetUrlSeverestThreatType(check->urls[i], + expected_full_hashes, + &threat_index); + if (threat != SB_THREAT_TYPE_SAFE) { check->url_results[i] = threat; - check->url_metadata[i] = full_hashes[threat_index].metadata; + check->url_metadata[i] = expected_full_hashes[threat_index].metadata; is_threat = true; } } for (size_t i = 0; i < check->full_hashes.size(); ++i) { SBThreatType threat = - GetHashSeverestThreatType(check->full_hashes[i], full_hashes); - if (threat != SB_THREAT_TYPE_SAFE && - IsExpectedThreat(threat, check->expected_threats)) { + GetHashSeverestThreatType(check->full_hashes[i], expected_full_hashes); + if (threat != SB_THREAT_TYPE_SAFE) { check->full_hash_results[i] = threat; is_threat = true; } diff --git a/chrome/browser/safe_browsing/local_database_manager_unittest.cc b/chrome/browser/safe_browsing/local_database_manager_unittest.cc index e4d6861..5ea449d 100644 --- a/chrome/browser/safe_browsing/local_database_manager_unittest.cc +++ b/chrome/browser/safe_browsing/local_database_manager_unittest.cc @@ -23,74 +23,122 @@ using content::TestBrowserThreadBundle; namespace safe_browsing { -namespace { - -class TestClient : public SafeBrowsingDatabaseManager::Client { - public: - TestClient() {} - ~TestClient() override {} - - void OnCheckBrowseUrlResult(const GURL& url, - SBThreatType threat_type, - const std::string& metadata) override {} - - void OnCheckDownloadUrlResult(const std::vector<GURL>& url_chain, - SBThreatType threat_type) override {} - - private: - DISALLOW_COPY_AND_ASSIGN(TestClient); -}; - -} // namespace - class SafeBrowsingDatabaseManagerTest : public PlatformTest { public: + struct HostListPair { + std::string host; + std::string list_type; + }; + bool RunSBHashTest(const ListType list_type, const std::vector<SBThreatType>& expected_threats, - const std::string& result_list); + const std::vector<std::string>& result_lists); + bool RunUrlTest( + const GURL& url, ListType list_type, + const std::vector<SBThreatType>& expected_threats, + const std::vector<HostListPair>& host_list_results); private: + bool RunTest(LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck* check, + const std::vector<SBFullHashResult>& hash_results); + TestBrowserThreadBundle thread_bundle_; }; bool SafeBrowsingDatabaseManagerTest::RunSBHashTest( const ListType list_type, const std::vector<SBThreatType>& expected_threats, - const std::string& result_list) { + const std::vector<std::string>& result_lists) { + const SBFullHash same_full_hash = {}; + scoped_ptr<LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck> check( + new LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck( + std::vector<GURL>(), std::vector<SBFullHash>(1, same_full_hash), NULL, + list_type, expected_threats)); + + std::vector<SBFullHashResult> fake_results; + for (const auto& result_list : result_lists) { + const SBFullHashResult full_hash_result = {same_full_hash, + GetListId(result_list)}; + fake_results.push_back(full_hash_result); + } + return RunTest(check.get(), fake_results); +} + +bool SafeBrowsingDatabaseManagerTest::RunUrlTest( + const GURL& url, ListType list_type, + const std::vector<SBThreatType>& expected_threats, + const std::vector<HostListPair>& host_list_results) { + scoped_ptr<LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck> check( + new LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck( + std::vector<GURL>(1, url), std::vector<SBFullHash>(), NULL, + list_type, expected_threats)); + std::vector<SBFullHashResult> full_hash_results; + for (const auto& host_list : host_list_results) { + SBFullHashResult hash_result = + {SBFullHashForString(host_list.host), GetListId(host_list.list_type)}; + full_hash_results.push_back(hash_result); + } + return RunTest(check.get(), full_hash_results); +} + +bool SafeBrowsingDatabaseManagerTest::RunTest( + LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck* check, + const std::vector<SBFullHashResult>& hash_results) { scoped_refptr<SafeBrowsingService> sb_service_( SafeBrowsingService::CreateSafeBrowsingService()); scoped_refptr<LocalSafeBrowsingDatabaseManager> db_manager_( new LocalSafeBrowsingDatabaseManager(sb_service_)); - const SBFullHash same_full_hash = {}; - - LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck* check = - new LocalSafeBrowsingDatabaseManager::SafeBrowsingCheck( - std::vector<GURL>(), std::vector<SBFullHash>(1, same_full_hash), NULL, - list_type, expected_threats); db_manager_->checks_.insert(check); - const SBFullHashResult full_hash_result = {same_full_hash, - GetListId(result_list)}; - - std::vector<SBFullHashResult> fake_results(1, full_hash_result); - bool result = db_manager_->HandleOneCheck(check, fake_results); + bool result = db_manager_->HandleOneCheck(check, hash_results); db_manager_->checks_.erase(check); - delete check; return result; } -TEST_F(SafeBrowsingDatabaseManagerTest, CheckCorrespondsListType) { +TEST_F(SafeBrowsingDatabaseManagerTest, CheckCorrespondsListTypeForHash) { std::vector<SBThreatType> malware_threat(1, SB_THREAT_TYPE_BINARY_MALWARE_URL); - EXPECT_FALSE(RunSBHashTest(BINURL, malware_threat, kMalwareList)); - EXPECT_TRUE(RunSBHashTest(BINURL, malware_threat, kBinUrlList)); + EXPECT_FALSE(RunSBHashTest(BINURL, malware_threat, {kMalwareList})); + EXPECT_TRUE(RunSBHashTest(BINURL, malware_threat, {kBinUrlList})); // Check for multiple threats std::vector<SBThreatType> multiple_threats; multiple_threats.push_back(SB_THREAT_TYPE_URL_MALWARE); multiple_threats.push_back(SB_THREAT_TYPE_URL_PHISHING); - EXPECT_FALSE(RunSBHashTest(MALWARE, multiple_threats, kBinUrlList)); - EXPECT_TRUE(RunSBHashTest(MALWARE, multiple_threats, kMalwareList)); + EXPECT_FALSE(RunSBHashTest(MALWARE, multiple_threats, {kBinUrlList})); + EXPECT_TRUE(RunSBHashTest(MALWARE, multiple_threats, {kMalwareList})); + + // Check for multiple hash hits + std::vector<SBThreatType> unwanted_threat = {SB_THREAT_TYPE_URL_UNWANTED}; + std::vector<std::string> hash_hits = {kMalwareList, kUnwantedUrlList}; + EXPECT_TRUE(RunSBHashTest(UNWANTEDURL, unwanted_threat, hash_hits)); +} + +TEST_F(SafeBrowsingDatabaseManagerTest, CheckCorrespondsListTypeForUrl) { + const GURL url("http://www.host.com/index.html"); + const std::string host1 = "host.com/"; + const std::string host2 = "www.host.com/"; + const std::vector<HostListPair> malware_list_result = + {{host1, kMalwareList}}; + const std::vector<HostListPair> binurl_list_result = + {{host2, kBinUrlList}}; + + std::vector<SBThreatType> malware_threat = + {SB_THREAT_TYPE_BINARY_MALWARE_URL}; + EXPECT_FALSE(RunUrlTest(url, BINURL, malware_threat, malware_list_result)); + EXPECT_TRUE(RunUrlTest(url, BINURL, malware_threat, binurl_list_result)); + + // Check for multiple expected threats + std::vector<SBThreatType> multiple_threats = + {SB_THREAT_TYPE_URL_MALWARE, SB_THREAT_TYPE_URL_PHISHING}; + EXPECT_FALSE(RunUrlTest(url, MALWARE, multiple_threats, binurl_list_result)); + EXPECT_TRUE(RunUrlTest(url, MALWARE, multiple_threats, malware_list_result)); + + // Check for multiple database hits + std::vector<SBThreatType> unwanted_threat = {SB_THREAT_TYPE_URL_UNWANTED}; + std::vector<HostListPair> multiple_results = { + {host1, kMalwareList}, {host2, kUnwantedUrlList}}; + EXPECT_TRUE(RunUrlTest(url, UNWANTEDURL, unwanted_threat, multiple_results)); } TEST_F(SafeBrowsingDatabaseManagerTest, GetUrlSeverestThreatType) { @@ -199,7 +247,7 @@ TEST_F(SafeBrowsingDatabaseManagerTest, ServiceStopWithPendingChecks) { SafeBrowsingService::CreateSafeBrowsingService()); scoped_refptr<LocalSafeBrowsingDatabaseManager> db_manager( new LocalSafeBrowsingDatabaseManager(sb_service)); - TestClient client; + SafeBrowsingDatabaseManager::Client client; // Start the service and flush tasks to ensure database is made available. db_manager->StartOnIOThread(); |